From e6918187568dbd01842d8d1d2c808ce16a894239 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 21 Apr 2024 13:54:28 +0200 Subject: Adding upstream version 18.2.2. Signed-off-by: Daniel Baumann --- src/arrow/java/algorithm/pom.xml | 55 +++ .../algorithm/deduplicate/DeduplicationUtils.java | 96 +++++ .../deduplicate/VectorRunDeduplicator.java | 108 ++++++ .../algorithm/dictionary/DictionaryBuilder.java | 72 ++++ .../algorithm/dictionary/DictionaryEncoder.java | 39 ++ .../HashTableBasedDictionaryBuilder.java | 153 ++++++++ .../dictionary/HashTableDictionaryEncoder.java | 146 +++++++ .../dictionary/LinearDictionaryEncoder.java | 112 ++++++ .../dictionary/SearchDictionaryEncoder.java | 100 +++++ .../SearchTreeBasedDictionaryBuilder.java | 146 +++++++ .../arrow/algorithm/misc/PartialSumUtils.java | 119 ++++++ .../apache/arrow/algorithm/rank/VectorRank.java | 89 +++++ .../arrow/algorithm/search/ParallelSearcher.java | 190 +++++++++ .../algorithm/search/VectorRangeSearcher.java | 108 ++++++ .../arrow/algorithm/search/VectorSearcher.java | 88 +++++ .../algorithm/sort/CompositeVectorComparator.java | 71 ++++ .../algorithm/sort/DefaultVectorComparators.java | 431 +++++++++++++++++++++ .../sort/FixedWidthInPlaceVectorSorter.java | 169 ++++++++ .../sort/FixedWidthOutOfPlaceVectorSorter.java | 80 ++++ .../arrow/algorithm/sort/InPlaceVectorSorter.java | 37 ++ .../apache/arrow/algorithm/sort/IndexSorter.java | 180 +++++++++ .../arrow/algorithm/sort/InsertionSorter.java | 74 ++++ .../arrow/algorithm/sort/OffHeapIntStack.java | 72 ++++ .../algorithm/sort/OutOfPlaceVectorSorter.java | 37 ++ .../algorithm/sort/StableVectorComparator.java | 66 ++++ .../sort/VariableWidthOutOfPlaceVectorSorter.java | 93 +++++ .../algorithm/sort/VectorValueComparator.java | 123 ++++++ .../deduplicate/TestDeduplicationUtils.java | 135 +++++++ .../deduplicate/TestVectorRunDeduplicator.java | 131 +++++++ .../TestHashTableBasedDictionaryBuilder.java | 202 ++++++++++ .../dictionary/TestHashTableDictionaryEncoder.java | 350 +++++++++++++++++ .../dictionary/TestLinearDictionaryEncoder.java | 350 +++++++++++++++++ .../dictionary/TestSearchDictionaryEncoder.java | 357 +++++++++++++++++ .../TestSearchTreeBasedDictionaryBuilder.java | 221 +++++++++++ .../arrow/algorithm/misc/TestPartialSumUtils.java | 138 +++++++ .../arrow/algorithm/rank/TestVectorRank.java | 145 +++++++ .../algorithm/search/TestParallelSearcher.java | 150 +++++++ .../algorithm/search/TestVectorRangeSearcher.java | 195 ++++++++++ .../arrow/algorithm/search/TestVectorSearcher.java | 299 ++++++++++++++ .../sort/TestCompositeVectorComparator.java | 112 ++++++ .../sort/TestDefaultVectorComparator.java | 393 +++++++++++++++++++ .../sort/TestFixedWidthInPlaceVectorSorter.java | 240 ++++++++++++ .../sort/TestFixedWidthOutOfPlaceVectorSorter.java | 365 +++++++++++++++++ .../algorithm/sort/TestFixedWidthSorting.java | 172 ++++++++ .../arrow/algorithm/sort/TestIndexSorter.java | 205 ++++++++++ .../arrow/algorithm/sort/TestInsertionSorter.java | 117 ++++++ .../arrow/algorithm/sort/TestOffHeapIntStack.java | 67 ++++ .../arrow/algorithm/sort/TestSortingUtil.java | 166 ++++++++ .../algorithm/sort/TestStableVectorComparator.java | 137 +++++++ .../TestVariableWidthOutOfPlaceVectorSorter.java | 99 +++++ .../algorithm/sort/TestVariableWidthSorting.java | 165 ++++++++ 51 files changed, 7965 insertions(+) create mode 100644 src/arrow/java/algorithm/pom.xml create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/deduplicate/DeduplicationUtils.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/deduplicate/VectorRunDeduplicator.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/DictionaryBuilder.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/DictionaryEncoder.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/HashTableBasedDictionaryBuilder.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/HashTableDictionaryEncoder.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/LinearDictionaryEncoder.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/SearchDictionaryEncoder.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/SearchTreeBasedDictionaryBuilder.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/misc/PartialSumUtils.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/rank/VectorRank.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorRangeSearcher.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorSearcher.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/CompositeVectorComparator.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthInPlaceVectorSorter.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthOutOfPlaceVectorSorter.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/InPlaceVectorSorter.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/IndexSorter.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/InsertionSorter.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/OffHeapIntStack.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/OutOfPlaceVectorSorter.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/StableVectorComparator.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VariableWidthOutOfPlaceVectorSorter.java create mode 100644 src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VectorValueComparator.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestDeduplicationUtils.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestVectorRunDeduplicator.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableBasedDictionaryBuilder.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableDictionaryEncoder.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestLinearDictionaryEncoder.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchDictionaryEncoder.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchTreeBasedDictionaryBuilder.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/misc/TestPartialSumUtils.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorRangeSearcher.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestCompositeVectorComparator.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthInPlaceVectorSorter.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthSorting.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestIndexSorter.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestInsertionSorter.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestOffHeapIntStack.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestSortingUtil.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestStableVectorComparator.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthOutOfPlaceVectorSorter.java create mode 100644 src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthSorting.java (limited to 'src/arrow/java/algorithm') diff --git a/src/arrow/java/algorithm/pom.xml b/src/arrow/java/algorithm/pom.xml new file mode 100644 index 000000000..fa4787d30 --- /dev/null +++ b/src/arrow/java/algorithm/pom.xml @@ -0,0 +1,55 @@ + + + + 4.0.0 + + org.apache.arrow + arrow-java-root + 6.0.1 + + arrow-algorithm + Arrow Algorithms + (Experimental/Contrib) A collection of algorithms for working with ValueVectors. + + + + org.apache.arrow + arrow-vector + ${project.version} + ${arrow.vector.classifier} + + + org.apache.arrow + arrow-vector + ${project.version} + test-jar + + + org.apache.arrow + arrow-memory-core + ${project.version} + + + org.apache.arrow + arrow-memory-netty + ${project.version} + test + + + io.netty + netty-common + + + + + + diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/deduplicate/DeduplicationUtils.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/deduplicate/DeduplicationUtils.java new file mode 100644 index 000000000..8811e43d3 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/deduplicate/DeduplicationUtils.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.deduplicate; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.compare.Range; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; +import org.apache.arrow.vector.util.DataSizeRoundingUtil; + +/** + * Utilities for vector deduplication. + */ +class DeduplicationUtils { + + /** + * Gets the start positions of the first distinct values in a vector. + * @param vector the target vector. + * @param runStarts the bit set to hold the start positions. + * @param vector type. + */ + public static void populateRunStartIndicators(V vector, ArrowBuf runStarts) { + int bufSize = DataSizeRoundingUtil.divideBy8Ceil(vector.getValueCount()); + Preconditions.checkArgument(runStarts.capacity() >= bufSize); + runStarts.setZero(0, bufSize); + + BitVectorHelper.setBit(runStarts, 0); + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector, vector, null); + Range range = new Range(0, 0, 1); + for (int i = 1; i < vector.getValueCount(); i++) { + range.setLeftStart(i).setRightStart(i - 1); + if (!visitor.rangeEquals(range)) { + BitVectorHelper.setBit(runStarts, i); + } + } + } + + /** + * Gets the run lengths, given the start positions. + * @param runStarts the bit set for start positions. + * @param runLengths the run length vector to populate. + * @param valueCount the number of values in the bit set. + */ + public static void populateRunLengths(ArrowBuf runStarts, IntVector runLengths, int valueCount) { + int curStart = 0; + int lengthIndex = 0; + for (int i = 1; i < valueCount; i++) { + if (BitVectorHelper.get(runStarts, i) != 0) { + // we get a new distinct value + runLengths.setSafe(lengthIndex++, i - curStart); + curStart = i; + } + } + + // process the last value + runLengths.setSafe(lengthIndex++, valueCount - curStart); + runLengths.setValueCount(lengthIndex); + } + + /** + * Gets distinct values from the input vector by removing adjacent + * duplicated values. + * @param indicators the bit set containing the start positions of distinct values. + * @param inputVector the input vector. + * @param outputVector the output vector. + * @param vector type. + */ + public static void populateDeduplicatedValues( + ArrowBuf indicators, V inputVector, V outputVector) { + int dstIdx = 0; + for (int srcIdx = 0; srcIdx < inputVector.getValueCount(); srcIdx++) { + if (BitVectorHelper.get(indicators, srcIdx) != 0) { + outputVector.copyFromSafe(srcIdx, dstIdx++, inputVector); + } + } + outputVector.setValueCount(dstIdx); + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/deduplicate/VectorRunDeduplicator.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/deduplicate/VectorRunDeduplicator.java new file mode 100644 index 000000000..5ef03cbe4 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/deduplicate/VectorRunDeduplicator.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.deduplicate; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.util.DataSizeRoundingUtil; + +/** + * Remove adjacent equal elements from a vector. + * If the vector is sorted, it removes all duplicated values in the vector. + * @param vector type. + */ +public class VectorRunDeduplicator implements AutoCloseable { + + /** + * Bit set for distinct values. + * If the value at some index is not equal to the previous value, + * its bit is set to 1, otherwise its bit is set to 0. + */ + private ArrowBuf distinctValueBuffer; + + /** + * The vector to deduplicate. + */ + private final V vector; + + private final BufferAllocator allocator; + + /** + * Constructs a vector run deduplicator for a given vector. + * @param vector the vector to deduplicate. Ownership is NOT taken. + * @param allocator the allocator used for allocating buffers for start indices. + */ + public VectorRunDeduplicator(V vector, BufferAllocator allocator) { + this.vector = vector; + this.allocator = allocator; + } + + private void createDistinctValueBuffer() { + Preconditions.checkArgument(distinctValueBuffer == null); + int bufSize = DataSizeRoundingUtil.divideBy8Ceil(vector.getValueCount()); + distinctValueBuffer = allocator.buffer(bufSize); + DeduplicationUtils.populateRunStartIndicators(vector, distinctValueBuffer); + } + + /** + * Gets the number of values which are different from their predecessor. + * @return the run count. + */ + public int getRunCount() { + if (distinctValueBuffer == null) { + createDistinctValueBuffer(); + } + return vector.getValueCount() - BitVectorHelper.getNullCount(distinctValueBuffer, vector.getValueCount()); + } + + /** + * Gets the vector with deduplicated adjacent values removed. + * @param outVector the output vector. + */ + public void populateDeduplicatedValues(V outVector) { + if (distinctValueBuffer == null) { + createDistinctValueBuffer(); + } + + DeduplicationUtils.populateDeduplicatedValues(distinctValueBuffer, vector, outVector); + } + + /** + * Gets the length of each distinct value. + * @param lengthVector the vector for holding length values. + */ + public void populateRunLengths(IntVector lengthVector) { + if (distinctValueBuffer == null) { + createDistinctValueBuffer(); + } + + DeduplicationUtils.populateRunLengths(distinctValueBuffer, lengthVector, vector.getValueCount()); + } + + @Override + public void close() { + if (distinctValueBuffer != null) { + distinctValueBuffer.close(); + distinctValueBuffer = null; + } + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/DictionaryBuilder.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/DictionaryBuilder.java new file mode 100644 index 000000000..398368d1f --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/DictionaryBuilder.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import org.apache.arrow.vector.ValueVector; + +/** + * A dictionary builder is intended for the scenario frequently encountered in practice: + * the dictionary is not known a priori, so it is generated dynamically. + * In particular, when a new value arrives, it is tested to check if it is already + * in the dictionary. If so, it is simply neglected, otherwise, it is added to the dictionary. + *

+ * The dictionary builder is intended to build a single dictionary. + * So it cannot be used for different dictionaries. + *

+ *

Below gives the sample code for using the dictionary builder + *

{@code
+ * DictionaryBuilder dictionaryBuilder = ...
+ * ...
+ * dictionaryBuild.addValue(newValue);
+ * ...
+ * }
+ *

+ *

+ * With the above code, the dictionary vector will be populated, + * and it can be retrieved by the {@link DictionaryBuilder#getDictionary()} method. + * After that, dictionary encoding can proceed with the populated dictionary.. + *

+ * + * @param the dictionary vector type. + */ +public interface DictionaryBuilder { + + /** + * Try to add all values from the target vector to the dictionary. + * + * @param targetVector the target vector containing values to probe. + * @return the number of values actually added to the dictionary. + */ + int addValues(V targetVector); + + /** + * Try to add an element from the target vector to the dictionary. + * + * @param targetVector the target vector containing new element. + * @param targetIndex the index of the new element in the target vector. + * @return the index of the new element in the dictionary. + */ + int addValue(V targetVector, int targetIndex); + + /** + * Gets the dictionary built. + * + * @return the dictionary. + */ + V getDictionary(); +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/DictionaryEncoder.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/DictionaryEncoder.java new file mode 100644 index 000000000..cda7b3bf9 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/DictionaryEncoder.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.ValueVector; + +/** + * A dictionary encoder translates one vector into another one based on a dictionary vector. + * According to Arrow specification, the encoded vector must be an integer based vector, which + * is the index of the original vector element in the dictionary. + * @param type of the encoded vector. + * @param type of the vector to encode. It is also the type of the dictionary vector. + */ +public interface DictionaryEncoder { + + /** + * Translates an input vector into an output vector. + * @param input the input vector. + * @param output the output vector. Note that it must be in a fresh state. At least, + * all its validity bits should be clear. + */ + void encode(D input, E output); +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/HashTableBasedDictionaryBuilder.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/HashTableBasedDictionaryBuilder.java new file mode 100644 index 000000000..dd2b73498 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/HashTableBasedDictionaryBuilder.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import java.util.HashMap; + +import org.apache.arrow.memory.util.ArrowBufPointer; +import org.apache.arrow.memory.util.hash.ArrowBufHasher; +import org.apache.arrow.memory.util.hash.SimpleHasher; +import org.apache.arrow.vector.ElementAddressableVector; + +/** + * This class builds the dictionary based on a hash table. + * Each add operation can be finished in O(1) time, + * where n is the current dictionary size. + * + * @param the dictionary vector type. + */ +public class HashTableBasedDictionaryBuilder implements DictionaryBuilder { + + /** + * The dictionary to be built. + */ + private final V dictionary; + + /** + * If null should be encoded. + */ + private final boolean encodeNull; + + /** + * The hash map for distinct dictionary entries. + * The key is the pointer to the dictionary element, whereas the value is the index in the dictionary. + */ + private HashMap hashMap = new HashMap<>(); + + /** + * The hasher used for calculating the hash code. + */ + private final ArrowBufHasher hasher; + + /** + * Next pointer to try to add to the hash table. + */ + private ArrowBufPointer nextPointer; + + /** + * Constructs a hash table based dictionary builder. + * + * @param dictionary the dictionary to populate. + */ + public HashTableBasedDictionaryBuilder(V dictionary) { + this(dictionary, false); + } + + /** + * Constructs a hash table based dictionary builder. + * + * @param dictionary the dictionary to populate. + * @param encodeNull if null values should be added to the dictionary. + */ + public HashTableBasedDictionaryBuilder(V dictionary, boolean encodeNull) { + this(dictionary, encodeNull, SimpleHasher.INSTANCE); + } + + /** + * Constructs a hash table based dictionary builder. + * + * @param dictionary the dictionary to populate. + * @param encodeNull if null values should be added to the dictionary. + * @param hasher the hasher used to compute the hash code. + */ + public HashTableBasedDictionaryBuilder(V dictionary, boolean encodeNull, ArrowBufHasher hasher) { + this.dictionary = dictionary; + this.encodeNull = encodeNull; + this.hasher = hasher; + this.nextPointer = new ArrowBufPointer(hasher); + } + + /** + * Gets the dictionary built. + * + * @return the dictionary. + */ + @Override + public V getDictionary() { + return dictionary; + } + + /** + * Try to add all values from the target vector to the dictionary. + * + * @param targetVector the target vector containing values to probe. + * @return the number of values actually added to the dictionary. + */ + @Override + public int addValues(V targetVector) { + int oldDictSize = dictionary.getValueCount(); + for (int i = 0; i < targetVector.getValueCount(); i++) { + if (!encodeNull && targetVector.isNull(i)) { + continue; + } + addValue(targetVector, i); + } + + return dictionary.getValueCount() - oldDictSize; + } + + /** + * Try to add an element from the target vector to the dictionary. + * + * @param targetVector the target vector containing new element. + * @param targetIndex the index of the new element in the target vector. + * @return the index of the new element in the dictionary. + */ + @Override + public int addValue(V targetVector, int targetIndex) { + targetVector.getDataPointer(targetIndex, nextPointer); + + Integer index = hashMap.get(nextPointer); + if (index == null) { + // a new dictionary element is found + + // insert it to the dictionary + int dictSize = dictionary.getValueCount(); + dictionary.copyFromSafe(targetIndex, dictSize, targetVector); + dictionary.setValueCount(dictSize + 1); + dictionary.getDataPointer(dictSize, nextPointer); + + // insert it to the hash map + hashMap.put(nextPointer, dictSize); + nextPointer = new ArrowBufPointer(hasher); + + return dictSize; + } + return index; + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/HashTableDictionaryEncoder.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/HashTableDictionaryEncoder.java new file mode 100644 index 000000000..bea1a784c --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/HashTableDictionaryEncoder.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import java.util.HashMap; + +import org.apache.arrow.memory.util.ArrowBufPointer; +import org.apache.arrow.memory.util.hash.ArrowBufHasher; +import org.apache.arrow.memory.util.hash.SimpleHasher; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.ElementAddressableVector; + +/** + * Dictionary encoder based on hash table. + * @param encoded vector type. + * @param decoded vector type, which is also the dictionary type. + */ +public class HashTableDictionaryEncoder + implements DictionaryEncoder { + + /** + * The dictionary for encoding/decoding. + * It must be sorted. + */ + private final D dictionary; + + /** + * The hasher used to compute the hash code. + */ + private final ArrowBufHasher hasher; + + /** + * A flag indicating if null should be encoded. + */ + private final boolean encodeNull; + + /** + * The hash map for distinct dictionary entries. + * The key is the pointer to the dictionary element, whereas the value is the index in the dictionary. + */ + private HashMap hashMap = new HashMap<>(); + + /** + * The pointer used to probe each element to encode. + */ + private ArrowBufPointer reusablePointer; + + /** + * Constructs a dictionary encoder. + * @param dictionary the dictionary. + * + */ + public HashTableDictionaryEncoder(D dictionary) { + this(dictionary, false); + } + + /** + * Constructs a dictionary encoder. + * @param dictionary the dictionary. + * @param encodeNull a flag indicating if null should be encoded. + * It determines the behaviors for processing null values in the input during encoding/decoding. + *
  • + * For encoding, when a null is encountered in the input, + * 1) If the flag is set to true, the encoder searches for the value in the dictionary, + * and outputs the index in the dictionary. + * 2) If the flag is set to false, the encoder simply produces a null in the output. + *
  • + *
  • + * For decoding, when a null is encountered in the input, + * 1) If the flag is set to true, the decoder should never expect a null in the input. + * 2) If set to false, the decoder simply produces a null in the output. + *
  • + */ + public HashTableDictionaryEncoder(D dictionary, boolean encodeNull) { + this(dictionary, encodeNull, SimpleHasher.INSTANCE); + } + + /** + * Constructs a dictionary encoder. + * @param dictionary the dictionary. + * @param encodeNull a flag indicating if null should be encoded. + * It determines the behaviors for processing null values in the input during encoding. + * When a null is encountered in the input, + * 1) If the flag is set to true, the encoder searches for the value in the dictionary, + * and outputs the index in the dictionary. + * 2) If the flag is set to false, the encoder simply produces a null in the output. + * @param hasher the hasher used to calculate the hash code. + */ + public HashTableDictionaryEncoder(D dictionary, boolean encodeNull, ArrowBufHasher hasher) { + this.dictionary = dictionary; + this.hasher = hasher; + this.encodeNull = encodeNull; + + reusablePointer = new ArrowBufPointer(hasher); + + buildHashMap(); + } + + private void buildHashMap() { + for (int i = 0; i < dictionary.getValueCount(); i++) { + ArrowBufPointer pointer = new ArrowBufPointer(hasher); + dictionary.getDataPointer(i, pointer); + hashMap.put(pointer, i); + } + } + + /** + * Encodes an input vector by a hash table. + * So the algorithm takes O(n) time, where n is the length of the input vector. + * + * @param input the input vector. + * @param output the output vector. + **/ + @Override + public void encode(D input, E output) { + for (int i = 0; i < input.getValueCount(); i++) { + if (!encodeNull && input.isNull(i)) { + continue; + } + + input.getDataPointer(i, reusablePointer); + Integer index = hashMap.get(reusablePointer); + + if (index == null) { + throw new IllegalArgumentException("The data element is not found in the dictionary"); + } + output.setWithPossibleTruncate(i, index); + } + output.setValueCount(input.getValueCount()); + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/LinearDictionaryEncoder.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/LinearDictionaryEncoder.java new file mode 100644 index 000000000..84a3a96af --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/LinearDictionaryEncoder.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.compare.Range; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; + +/** + * Dictionary encoder based on linear search. + * @param encoded vector type. + * @param decoded vector type, which is also the dictionary type. + */ +public class LinearDictionaryEncoder + implements DictionaryEncoder { + + /** + * The dictionary for encoding. + */ + private final D dictionary; + + /** + * A flag indicating if null should be encoded. + */ + private final boolean encodeNull; + + private RangeEqualsVisitor equalizer; + + private Range range; + + /** + * Constructs a dictionary encoder, with the encode null flag set to false. + * @param dictionary the dictionary. Its entries should be sorted in the non-increasing order of their frequency. + * Otherwise, the encoder still produces correct results, but at the expense of performance overhead. + */ + public LinearDictionaryEncoder(D dictionary) { + this(dictionary, false); + } + + /** + * Constructs a dictionary encoder. + * @param dictionary the dictionary. Its entries should be sorted in the non-increasing order of their frequency. + * Otherwise, the encoder still produces correct results, but at the expense of performance overhead. + * @param encodeNull a flag indicating if null should be encoded. + * It determines the behaviors for processing null values in the input during encoding. + * When a null is encountered in the input, + * 1) If the flag is set to true, the encoder searches for the value in the dictionary, + * and outputs the index in the dictionary. + * 2) If the flag is set to false, the encoder simply produces a null in the output. + */ + public LinearDictionaryEncoder(D dictionary, boolean encodeNull) { + this.dictionary = dictionary; + this.encodeNull = encodeNull; + + // temporarily set left and right vectors to dictionary + equalizer = new RangeEqualsVisitor(dictionary, dictionary, null); + range = new Range(0, 0, 1); + } + + /** + * Encodes an input vector by linear search. + * When the dictionary is sorted in the non-increasing order of the entry frequency, + * it will have constant time complexity, with no extra memory requirement. + * @param input the input vector. + * @param output the output vector. Note that it must be in a fresh state. At least, + * all its validity bits should be clear. + */ + @Override + public void encode(D input, E output) { + for (int i = 0; i < input.getValueCount(); i++) { + if (!encodeNull && input.isNull(i)) { + // for this case, we should simply output a null in the output. + // by assuming the output vector is fresh, we do nothing here. + continue; + } + + int index = linearSearch(input, i); + if (index == -1) { + throw new IllegalArgumentException("The data element is not found in the dictionary: " + i); + } + output.setWithPossibleTruncate(i, index); + } + output.setValueCount(input.getValueCount()); + } + + private int linearSearch(D input, int index) { + range.setLeftStart(index); + for (int i = 0; i < dictionary.getValueCount(); i++) { + range.setRightStart(i); + if (input.accept(equalizer, range)) { + return i; + } + } + return -1; + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/SearchDictionaryEncoder.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/SearchDictionaryEncoder.java new file mode 100644 index 000000000..1dbf65819 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/SearchDictionaryEncoder.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import org.apache.arrow.algorithm.search.VectorSearcher; +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Dictionary encoder based on searching. + * @param encoded vector type. + * @param decoded vector type, which is also the dictionary type. + */ +public class SearchDictionaryEncoder + implements DictionaryEncoder { + + /** + * The dictionary for encoding/decoding. + * It must be sorted. + */ + private final D dictionary; + + /** + * The criteria by which the dictionary is sorted. + */ + private final VectorValueComparator comparator; + + /** + * A flag indicating if null should be encoded. + */ + private final boolean encodeNull; + + /** + * Constructs a dictionary encoder. + * @param dictionary the dictionary. It must be in sorted order. + * @param comparator the criteria for sorting. + */ + public SearchDictionaryEncoder(D dictionary, VectorValueComparator comparator) { + this(dictionary, comparator, false); + } + + /** + * Constructs a dictionary encoder. + * @param dictionary the dictionary. It must be in sorted order. + * @param comparator the criteria for sorting. + * @param encodeNull a flag indicating if null should be encoded. + * It determines the behaviors for processing null values in the input during encoding. + * When a null is encountered in the input, + * 1) If the flag is set to true, the encoder searches for the value in the dictionary, + * and outputs the index in the dictionary. + * 2) If the flag is set to false, the encoder simply produces a null in the output. + */ + public SearchDictionaryEncoder(D dictionary, VectorValueComparator comparator, boolean encodeNull) { + this.dictionary = dictionary; + this.comparator = comparator; + this.encodeNull = encodeNull; + } + + /** + * Encodes an input vector by binary search. + * So the algorithm takes O(n * log(m)) time, where n is the length of the input vector, + * and m is the length of the dictionary. + * @param input the input vector. + * @param output the output vector. Note that it must be in a fresh state. At least, + * all its validity bits should be clear. + */ + @Override + public void encode(D input, E output) { + for (int i = 0; i < input.getValueCount(); i++) { + if (!encodeNull && input.isNull(i)) { + // for this case, we should simply output a null in the output. + // by assuming the output vector is fresh, we do nothing here. + continue; + } + + int index = VectorSearcher.binarySearch(dictionary, comparator, input, i); + if (index == -1) { + throw new IllegalArgumentException("The data element is not found in the dictionary: " + i); + } + output.setWithPossibleTruncate(i, index); + } + output.setValueCount(input.getValueCount()); + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/SearchTreeBasedDictionaryBuilder.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/SearchTreeBasedDictionaryBuilder.java new file mode 100644 index 000000000..f9cd77daa --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/dictionary/SearchTreeBasedDictionaryBuilder.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import java.util.TreeSet; + +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.vector.ValueVector; + +/** + * This class builds the dictionary based on a binary search tree. + * Each add operation can be finished in O(log(n)) time, + * where n is the current dictionary size. + * + * @param the dictionary vector type. + */ +public class SearchTreeBasedDictionaryBuilder implements DictionaryBuilder { + + /** + * The dictionary to be built. + */ + private final V dictionary; + + /** + * The criteria for sorting in the search tree. + */ + protected final VectorValueComparator comparator; + + /** + * If null should be encoded. + */ + private final boolean encodeNull; + + /** + * The search tree for storing the value index. + */ + private TreeSet searchTree; + + /** + * Construct a search tree-based dictionary builder. + * @param dictionary the dictionary vector. + * @param comparator the criteria for value equality. + */ + public SearchTreeBasedDictionaryBuilder(V dictionary, VectorValueComparator comparator) { + this(dictionary, comparator, false); + } + + /** + * Construct a search tree-based dictionary builder. + * @param dictionary the dictionary vector. + * @param comparator the criteria for value equality. + * @param encodeNull if null values should be added to the dictionary. + */ + public SearchTreeBasedDictionaryBuilder(V dictionary, VectorValueComparator comparator, boolean encodeNull) { + this.dictionary = dictionary; + this.comparator = comparator; + this.encodeNull = encodeNull; + this.comparator.attachVector(dictionary); + + searchTree = new TreeSet<>((index1, index2) -> comparator.compare(index1, index2)); + } + + /** + * Gets the dictionary built. + * Please note that the dictionary is not in sorted order. + * Instead, its order is determined by the order of element insertion. + * To get the dictionary in sorted order, please use + * {@link SearchTreeBasedDictionaryBuilder#populateSortedDictionary(ValueVector)}. + * @return the dictionary. + */ + @Override + public V getDictionary() { + return dictionary; + } + + /** + * Try to add all values from the target vector to the dictionary. + * @param targetVector the target vector containing values to probe. + * @return the number of values actually added to the dictionary. + */ + @Override + public int addValues(V targetVector) { + int oldDictSize = dictionary.getValueCount(); + for (int i = 0; i < targetVector.getValueCount(); i++) { + if (!encodeNull && targetVector.isNull(i)) { + continue; + } + addValue(targetVector, i); + } + return dictionary.getValueCount() - oldDictSize; + } + + /** + * Try to add an element from the target vector to the dictionary. + * @param targetVector the target vector containing new element. + * @param targetIndex the index of the new element in the target vector. + * @return the index of the new element in the dictionary. + */ + @Override + public int addValue(V targetVector, int targetIndex) { + // first copy the value to the end of the dictionary + int dictSize = dictionary.getValueCount(); + dictionary.copyFromSafe(targetIndex, dictSize, targetVector); + + // try to add the value to the dictionary, + // if an equal element does not exist. + // this operation can be done in O(log(n)) time. + if (searchTree.add(dictSize)) { + // the element is successfully added + dictionary.setValueCount(dictSize + 1); + return dictSize; + } else { + // the element is already in the dictionary + // find its index in O(log(n)) time. + return searchTree.ceiling(dictSize); + } + } + + /** + * Gets the sorted dictionary. + * Note that given the binary search tree, the sort can finish in O(n). + */ + public void populateSortedDictionary(V sortedDictionary) { + int idx = 0; + for (Integer dictIdx : searchTree) { + sortedDictionary.copyFromSafe(dictIdx, idx++, dictionary); + } + + sortedDictionary.setValueCount(dictionary.getValueCount()); + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/misc/PartialSumUtils.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/misc/PartialSumUtils.java new file mode 100644 index 000000000..f5e95cf10 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/misc/PartialSumUtils.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.misc; + +import org.apache.arrow.vector.BaseIntVector; + +/** + * Partial sum related utilities. + */ +public class PartialSumUtils { + + /** + * Converts an input vector to a partial sum vector. + * This is an inverse operation of {@link PartialSumUtils#toDeltaVector(BaseIntVector, BaseIntVector)}. + * Suppose we have input vector a and output vector b. + * Then we have b(0) = sumBase; b(i + 1) = b(i) + a(i) (i = 0, 1, 2, ...). + * @param deltaVector the input vector. + * @param partialSumVector the output vector. + * @param sumBase the base of the partial sums. + */ + public static void toPartialSumVector(BaseIntVector deltaVector, BaseIntVector partialSumVector, long sumBase) { + long sum = sumBase; + partialSumVector.setWithPossibleTruncate(0, sumBase); + + for (int i = 0; i < deltaVector.getValueCount(); i++) { + sum += deltaVector.getValueAsLong(i); + partialSumVector.setWithPossibleTruncate(i + 1, sum); + } + partialSumVector.setValueCount(deltaVector.getValueCount() + 1); + } + + /** + * Converts an input vector to the delta vector. + * This is an inverse operation of {@link PartialSumUtils#toPartialSumVector(BaseIntVector, BaseIntVector, long)}. + * Suppose we have input vector a and output vector b. + * Then we have b(i) = a(i + 1) - a(i) (i = 0, 1, 2, ...). + * @param partialSumVector the input vector. + * @param deltaVector the output vector. + */ + public static void toDeltaVector(BaseIntVector partialSumVector, BaseIntVector deltaVector) { + for (int i = 0; i < partialSumVector.getValueCount() - 1; i++) { + long delta = partialSumVector.getValueAsLong(i + 1) - partialSumVector.getValueAsLong(i); + deltaVector.setWithPossibleTruncate(i, delta); + } + deltaVector.setValueCount(partialSumVector.getValueCount() - 1); + } + + /** + * Given a value and a partial sum vector, finds its position in the partial sum vector. + * In particular, given an integer value a and partial sum vector v, we try to find a + * position i, so that v(i) <= a < v(i + 1). + * The algorithm is based on binary search, so it takes O(log(n)) time, where n is + * the length of the partial sum vector. + * @param partialSumVector the input partial sum vector. + * @param value the value to search. + * @return the position in the partial sum vector, if any, or -1, if none is found. + */ + public static int findPositionInPartialSumVector(BaseIntVector partialSumVector, long value) { + if (value < partialSumVector.getValueAsLong(0) || + value >= partialSumVector.getValueAsLong(partialSumVector.getValueCount() - 1)) { + return -1; + } + + int low = 0; + int high = partialSumVector.getValueCount() - 1; + while (low <= high) { + int mid = low + (high - low) / 2; + long midValue = partialSumVector.getValueAsLong(mid); + + if (midValue <= value) { + if (mid == partialSumVector.getValueCount() - 1) { + // the mid is the last element, we have found it + return mid; + } + long nextMidValue = partialSumVector.getValueAsLong(mid + 1); + if (value < nextMidValue) { + // midValue <= value < nextMidValue + // this is exactly what we want. + return mid; + } else { + // value >= nextMidValue + // continue to search from the next value on the right + low = mid + 1; + } + } else { + // midValue > value + long prevMidValue = partialSumVector.getValueAsLong(mid - 1); + if (prevMidValue <= value) { + // prevMidValue <= value < midValue + // this is exactly what we want + return mid - 1; + } else { + // prevMidValue > value + // continue to search from the previous value on the left + high = mid - 1; + } + } + } + throw new IllegalStateException("Should never get here"); + } + + private PartialSumUtils() { + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/rank/VectorRank.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/rank/VectorRank.java new file mode 100644 index 000000000..43c9a5b01 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/rank/VectorRank.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.rank; + +import java.util.stream.IntStream; + +import org.apache.arrow.algorithm.sort.IndexSorter; +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Utility for calculating ranks of vector elements. + * @param the vector type + */ +public class VectorRank { + + private VectorValueComparator comparator; + + /** + * Vector indices. + */ + private IntVector indices; + + private final BufferAllocator allocator; + + /** + * Constructs a vector rank utility. + * @param allocator the allocator to use. + */ + public VectorRank(BufferAllocator allocator) { + this.allocator = allocator; + } + + /** + * Given a rank r, gets the index of the element that is the rth smallest in the vector. + * The operation is performed without changing the vector, and takes O(n) time, + * where n is the length of the vector. + * @param vector the vector from which to get the element index. + * @param comparator the criteria for vector element comparison. + * @param rank the rank to determine. + * @return the element index with the given rank. + */ + public int indexAtRank(V vector, VectorValueComparator comparator, int rank) { + Preconditions.checkArgument(rank >= 0 && rank < vector.getValueCount()); + try { + indices = new IntVector("index vector", allocator); + indices.allocateNew(vector.getValueCount()); + IntStream.range(0, vector.getValueCount()).forEach(i -> indices.set(i, i)); + + comparator.attachVector(vector); + this.comparator = comparator; + + int pos = getRank(0, vector.getValueCount() - 1, rank); + return indices.get(pos); + } finally { + indices.close(); + } + } + + private int getRank(int low, int high, int rank) { + int mid = IndexSorter.partition(low, high, indices, comparator); + if (mid < rank) { + return getRank(mid + 1, high, rank); + } else if (mid > rank) { + return getRank(low, mid - 1, rank); + } else { + // mid == rank + return mid; + } + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java new file mode 100644 index 000000000..e93eb2c3d --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.search; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; + +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.compare.Range; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; + +/** + * Search for a value in the vector by multiple threads. + * This is often used in scenarios where the vector is large or + * low response time is required. + * @param the vector type. + */ +public class ParallelSearcher { + + /** + * The target vector to search. + */ + private final V vector; + + /** + * The thread pool. + */ + private final ExecutorService threadPool; + + /** + * The number of threads to use. + */ + private final int numThreads; + + /** + * The position of the key in the target vector, if any. + */ + private int keyPosition = -1; + + /** + * Constructs a parallel searcher. + * @param vector the vector to search. + * @param threadPool the thread pool to use. + * @param numThreads the number of threads to use. + */ + public ParallelSearcher(V vector, ExecutorService threadPool, int numThreads) { + this.vector = vector; + this.threadPool = threadPool; + this.numThreads = numThreads; + } + + private CompletableFuture[] initSearch() { + keyPosition = -1; + final CompletableFuture[] futures = new CompletableFuture[numThreads]; + for (int i = 0; i < futures.length; i++) { + futures[i] = new CompletableFuture<>(); + } + return futures; + } + + /** + * Search for the key in the target vector. The element-wise comparison is based on + * {@link RangeEqualsVisitor}, so there are two possible results for each element-wise + * comparison: equal and un-equal. + * @param keyVector the vector containing the search key. + * @param keyIndex the index of the search key in the key vector. + * @return the position of a matched value in the target vector, + * or -1 if none is found. Please note that if there are multiple + * matches of the key in the target vector, this method makes no + * guarantees about which instance is returned. + * For an alternative search implementation that always finds the first match of the key, + * see {@link VectorSearcher#linearSearch(ValueVector, VectorValueComparator, ValueVector, int)}. + * @throws ExecutionException if an exception occurs in a thread. + * @throws InterruptedException if a thread is interrupted. + */ + public int search(V keyVector, int keyIndex) throws ExecutionException, InterruptedException { + final CompletableFuture[] futures = initSearch(); + final int valueCount = vector.getValueCount(); + for (int i = 0; i < numThreads; i++) { + final int tid = i; + threadPool.submit(() -> { + // convert to long to avoid overflow + int start = (int) (((long) valueCount) * tid / numThreads); + int end = (int) ((long) valueCount) * (tid + 1) / numThreads; + + if (start >= end) { + // no data assigned to this task. + futures[tid].complete(false); + return; + } + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector, keyVector, null); + Range range = new Range(0, 0, 1); + for (int pos = start; pos < end; pos++) { + if (keyPosition != -1) { + // the key has been found by another task + futures[tid].complete(false); + return; + } + range.setLeftStart(pos).setRightStart(keyIndex); + if (visitor.rangeEquals(range)) { + keyPosition = pos; + futures[tid].complete(true); + return; + } + } + + // no match value is found. + futures[tid].complete(false); + }); + } + + CompletableFuture.allOf(futures).get(); + return keyPosition; + } + + /** + * Search for the key in the target vector. The element-wise comparison is based on + * {@link VectorValueComparator}, so there are three possible results for each element-wise + * comparison: less than, equal to and greater than. + * @param keyVector the vector containing the search key. + * @param keyIndex the index of the search key in the key vector. + * @param comparator the comparator for comparing the key against vector elements. + * @return the position of a matched value in the target vector, + * or -1 if none is found. Please note that if there are multiple + * matches of the key in the target vector, this method makes no + * guarantees about which instance is returned. + * For an alternative search implementation that always finds the first match of the key, + * see {@link VectorSearcher#linearSearch(ValueVector, VectorValueComparator, ValueVector, int)}. + * @throws ExecutionException if an exception occurs in a thread. + * @throws InterruptedException if a thread is interrupted. + */ + public int search( + V keyVector, int keyIndex, VectorValueComparator comparator) throws ExecutionException, InterruptedException { + final CompletableFuture[] futures = initSearch(); + final int valueCount = vector.getValueCount(); + for (int i = 0; i < numThreads; i++) { + final int tid = i; + threadPool.submit(() -> { + // convert to long to avoid overflow + int start = (int) (((long) valueCount) * tid / numThreads); + int end = (int) ((long) valueCount) * (tid + 1) / numThreads; + + if (start >= end) { + // no data assigned to this task. + futures[tid].complete(false); + return; + } + + VectorValueComparator localComparator = comparator.createNew(); + localComparator.attachVectors(vector, keyVector); + for (int pos = start; pos < end; pos++) { + if (keyPosition != -1) { + // the key has been found by another task + futures[tid].complete(false); + return; + } + if (localComparator.compare(pos, keyIndex) == 0) { + keyPosition = pos; + futures[tid].complete(true); + return; + } + } + + // no match value is found. + futures[tid].complete(false); + }); + } + + CompletableFuture.allOf(futures).get(); + return keyPosition; + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorRangeSearcher.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorRangeSearcher.java new file mode 100644 index 000000000..249194843 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorRangeSearcher.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.search; + +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.vector.ValueVector; + +/** + * Search for the range of a particular element in the target vector. + */ +public class VectorRangeSearcher { + + /** + * Result returned when a search fails. + */ + public static final int SEARCH_FAIL_RESULT = -1; + + /** + * Search for the first occurrence of an element. + * The search is based on the binary search algorithm. So the target vector must be sorted. + * @param targetVector the vector from which to perform the search. + * @param comparator the criterion for the comparison. + * @param keyVector the vector containing the element to search. + * @param keyIndex the index of the search key in the key vector. + * @param the vector type. + * @return the index of the first matched element if any, and -1 otherwise. + */ + public static int getFirstMatch( + V targetVector, VectorValueComparator comparator, V keyVector, int keyIndex) { + comparator.attachVectors(keyVector, targetVector); + + int ret = SEARCH_FAIL_RESULT; + + int low = 0; + int high = targetVector.getValueCount() - 1; + + while (low <= high) { + int mid = low + (high - low) / 2; + int result = comparator.compare(keyIndex, mid); + if (result < 0) { + // the key is smaller + high = mid - 1; + } else if (result > 0) { + // the key is larger + low = mid + 1; + } else { + // an equal element is found + // continue to go left-ward + ret = mid; + high = mid - 1; + } + } + return ret; + } + + /** + * Search for the last occurrence of an element. + * The search is based on the binary search algorithm. So the target vector must be sorted. + * @param targetVector the vector from which to perform the search. + * @param comparator the criterion for the comparison. + * @param keyVector the vector containing the element to search. + * @param keyIndex the index of the search key in the key vector. + * @param the vector type. + * @return the index of the last matched element if any, and -1 otherwise. + */ + public static int getLastMatch( + V targetVector, VectorValueComparator comparator, V keyVector, int keyIndex) { + comparator.attachVectors(keyVector, targetVector); + + int ret = SEARCH_FAIL_RESULT; + + int low = 0; + int high = targetVector.getValueCount() - 1; + + while (low <= high) { + int mid = low + (high - low) / 2; + int result = comparator.compare(keyIndex, mid); + if (result < 0) { + // the key is smaller + high = mid - 1; + } else if (result > 0) { + // the key is larger + low = mid + 1; + } else { + // an equal element is found, + // continue to go right-ward + ret = mid; + low = mid + 1; + } + } + return ret; + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorSearcher.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorSearcher.java new file mode 100644 index 000000000..646bca01b --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorSearcher.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.search; + +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.vector.ValueVector; + +/** + * Search for a particular element in the vector. + */ +public final class VectorSearcher { + + /** + * Result returned when a search fails. + */ + public static final int SEARCH_FAIL_RESULT = -1; + + /** + * Search for a particular element from the key vector in the target vector by binary search. + * The target vector must be sorted. + * @param targetVector the vector from which to perform the sort. + * @param comparator the criterion for the sort. + * @param keyVector the vector containing the element to search. + * @param keyIndex the index of the search key in the key vector. + * @param the vector type. + * @return the index of a matched element if any, and -1 otherwise. + */ + public static int binarySearch( + V targetVector, VectorValueComparator comparator, V keyVector, int keyIndex) { + comparator.attachVectors(keyVector, targetVector); + + // perform binary search + int low = 0; + int high = targetVector.getValueCount() - 1; + + while (low <= high) { + int mid = low + (high - low) / 2; + int cmp = comparator.compare(keyIndex, mid); + if (cmp < 0) { + high = mid - 1; + } else if (cmp > 0) { + low = mid + 1; + } else { + return mid; + } + } + return SEARCH_FAIL_RESULT; + } + + /** + * Search for a particular element from the key vector in the target vector by traversing the vector in sequence. + * @param targetVector the vector from which to perform the search. + * @param comparator the criterion for element equality. + * @param keyVector the vector containing the element to search. + * @param keyIndex the index of the search key in the key vector. + * @param the vector type. + * @return the index of a matched element if any, and -1 otherwise. + */ + public static int linearSearch( + V targetVector, VectorValueComparator comparator, V keyVector, int keyIndex) { + comparator.attachVectors(keyVector, targetVector); + for (int i = 0; i < targetVector.getValueCount(); i++) { + if (comparator.compare(keyIndex, i) == 0) { + return i; + } + } + return SEARCH_FAIL_RESULT; + } + + private VectorSearcher() { + + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/CompositeVectorComparator.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/CompositeVectorComparator.java new file mode 100644 index 000000000..ec74598e0 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/CompositeVectorComparator.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.vector.ValueVector; + +/** + * A composite vector comparator compares a number of vectors + * by a number of inner comparators. + *

    + * It works by first using the first comparator, if a non-zero value + * is returned, it simply returns it. Otherwise, it uses the second comparator, + * and so on, until a non-zero value is produced, or all inner comparators have + * been used. + *

    + */ +public class CompositeVectorComparator extends VectorValueComparator { + + private final VectorValueComparator[] innerComparators; + + public CompositeVectorComparator(VectorValueComparator[] innerComparators) { + this.innerComparators = innerComparators; + } + + @Override + public int compareNotNull(int index1, int index2) { + // short-cut for scenarios when the caller can be sure that the vectors are non-nullable. + for (int i = 0; i < innerComparators.length; i++) { + int result = innerComparators[i].compareNotNull(index1, index2); + if (result != 0) { + return result; + } + } + return 0; + } + + @Override + public int compare(int index1, int index2) { + for (int i = 0; i < innerComparators.length; i++) { + int result = innerComparators[i].compare(index1, index2); + if (result != 0) { + return result; + } + } + return 0; + } + + @Override + public VectorValueComparator createNew() { + VectorValueComparator[] newInnerComparators = new VectorValueComparator[innerComparators.length]; + for (int i = 0; i < innerComparators.length; i++) { + newInnerComparators[i] = innerComparators[i].createNew(); + } + return new CompositeVectorComparator(newInnerComparators); + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java new file mode 100644 index 000000000..c41821917 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.apache.arrow.vector.complex.BaseRepeatedValueVector.OFFSET_WIDTH; + +import org.apache.arrow.memory.util.ArrowBufPointer; +import org.apache.arrow.memory.util.ByteFunctionHelpers; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.BaseRepeatedValueVector; + +/** + * Default comparator implementations for different types of vectors. + */ +public class DefaultVectorComparators { + + /** + * Create the default comparator for the vector. + * @param vector the vector. + * @param the vector type. + * @return the default comparator. + */ + public static VectorValueComparator createDefaultComparator(T vector) { + if (vector instanceof BaseFixedWidthVector) { + if (vector instanceof TinyIntVector) { + return (VectorValueComparator) new ByteComparator(); + } else if (vector instanceof SmallIntVector) { + return (VectorValueComparator) new ShortComparator(); + } else if (vector instanceof IntVector) { + return (VectorValueComparator) new IntComparator(); + } else if (vector instanceof BigIntVector) { + return (VectorValueComparator) new LongComparator(); + } else if (vector instanceof Float4Vector) { + return (VectorValueComparator) new Float4Comparator(); + } else if (vector instanceof Float8Vector) { + return (VectorValueComparator) new Float8Comparator(); + } else if (vector instanceof UInt1Vector) { + return (VectorValueComparator) new UInt1Comparator(); + } else if (vector instanceof UInt2Vector) { + return (VectorValueComparator) new UInt2Comparator(); + } else if (vector instanceof UInt4Vector) { + return (VectorValueComparator) new UInt4Comparator(); + } else if (vector instanceof UInt8Vector) { + return (VectorValueComparator) new UInt8Comparator(); + } + } else if (vector instanceof BaseVariableWidthVector) { + return (VectorValueComparator) new VariableWidthComparator(); + } else if (vector instanceof BaseRepeatedValueVector) { + VectorValueComparator innerComparator = + createDefaultComparator(((BaseRepeatedValueVector) vector).getDataVector()); + return new RepeatedValueComparator(innerComparator); + } + + throw new IllegalArgumentException("No default comparator for " + vector.getClass().getCanonicalName()); + } + + /** + * Default comparator for bytes. + * The comparison is based on values, with null comes first. + */ + public static class ByteComparator extends VectorValueComparator { + + public ByteComparator() { + super(Byte.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + byte value1 = vector1.get(index1); + byte value2 = vector2.get(index2); + return value1 - value2; + } + + @Override + public VectorValueComparator createNew() { + return new ByteComparator(); + } + } + + /** + * Default comparator for short integers. + * The comparison is based on values, with null comes first. + */ + public static class ShortComparator extends VectorValueComparator { + + public ShortComparator() { + super(Short.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + short value1 = vector1.get(index1); + short value2 = vector2.get(index2); + return value1 - value2; + } + + @Override + public VectorValueComparator createNew() { + return new ShortComparator(); + } + } + + /** + * Default comparator for 32-bit integers. + * The comparison is based on int values, with null comes first. + */ + public static class IntComparator extends VectorValueComparator { + + public IntComparator() { + super(Integer.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + int value1 = vector1.get(index1); + int value2 = vector2.get(index2); + return Integer.compare(value1, value2); + } + + @Override + public VectorValueComparator createNew() { + return new IntComparator(); + } + } + + /** + * Default comparator for long integers. + * The comparison is based on values, with null comes first. + */ + public static class LongComparator extends VectorValueComparator { + + public LongComparator() { + super(Long.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + long value1 = vector1.get(index1); + long value2 = vector2.get(index2); + + return Long.compare(value1, value2); + } + + @Override + public VectorValueComparator createNew() { + return new LongComparator(); + } + } + + /** + * Default comparator for unsigned bytes. + * The comparison is based on values, with null comes first. + */ + public static class UInt1Comparator extends VectorValueComparator { + + public UInt1Comparator() { + super(1); + } + + @Override + public int compareNotNull(int index1, int index2) { + byte value1 = vector1.get(index1); + byte value2 = vector2.get(index2); + + return (value1 & 0xff) - (value2 & 0xff); + } + + @Override + public VectorValueComparator createNew() { + return new UInt1Comparator(); + } + } + + /** + * Default comparator for unsigned short integer. + * The comparison is based on values, with null comes first. + */ + public static class UInt2Comparator extends VectorValueComparator { + + public UInt2Comparator() { + super(2); + } + + @Override + public int compareNotNull(int index1, int index2) { + char value1 = vector1.get(index1); + char value2 = vector2.get(index2); + + // please note that we should not use the built-in + // Character#compare method here, as that method + // essentially compares char values as signed integers. + return (value1 & 0xffff) - (value2 & 0xffff); + } + + @Override + public VectorValueComparator createNew() { + return new UInt2Comparator(); + } + } + + /** + * Default comparator for unsigned integer. + * The comparison is based on values, with null comes first. + */ + public static class UInt4Comparator extends VectorValueComparator { + + public UInt4Comparator() { + super(4); + } + + @Override + public int compareNotNull(int index1, int index2) { + int value1 = vector1.get(index1); + int value2 = vector2.get(index2); + return ByteFunctionHelpers.unsignedIntCompare(value1, value2); + } + + @Override + public VectorValueComparator createNew() { + return new UInt4Comparator(); + } + } + + /** + * Default comparator for unsigned long integer. + * The comparison is based on values, with null comes first. + */ + public static class UInt8Comparator extends VectorValueComparator { + + public UInt8Comparator() { + super(8); + } + + @Override + public int compareNotNull(int index1, int index2) { + long value1 = vector1.get(index1); + long value2 = vector2.get(index2); + return ByteFunctionHelpers.unsignedLongCompare(value1, value2); + } + + @Override + public VectorValueComparator createNew() { + return new UInt8Comparator(); + } + } + + /** + * Default comparator for float type. + * The comparison is based on values, with null comes first. + */ + public static class Float4Comparator extends VectorValueComparator { + + public Float4Comparator() { + super(Float.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + float value1 = vector1.get(index1); + float value2 = vector2.get(index2); + + boolean isNan1 = Float.isNaN(value1); + boolean isNan2 = Float.isNaN(value2); + if (isNan1 || isNan2) { + if (isNan1 && isNan2) { + return 0; + } else if (isNan1) { + // nan is greater than any normal value + return 1; + } else { + return -1; + } + } + + return (int) Math.signum(value1 - value2); + } + + @Override + public VectorValueComparator createNew() { + return new Float4Comparator(); + } + } + + /** + * Default comparator for double type. + * The comparison is based on values, with null comes first. + */ + public static class Float8Comparator extends VectorValueComparator { + + public Float8Comparator() { + super(Double.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + double value1 = vector1.get(index1); + double value2 = vector2.get(index2); + + boolean isNan1 = Double.isNaN(value1); + boolean isNan2 = Double.isNaN(value2); + if (isNan1 || isNan2) { + if (isNan1 && isNan2) { + return 0; + } else if (isNan1) { + // nan is greater than any normal value + return 1; + } else { + return -1; + } + } + + return (int) Math.signum(value1 - value2); + } + + @Override + public VectorValueComparator createNew() { + return new Float8Comparator(); + } + } + + /** + * Default comparator for {@link org.apache.arrow.vector.BaseVariableWidthVector}. + * The comparison is in lexicographic order, with null comes first. + */ + public static class VariableWidthComparator extends VectorValueComparator { + + private ArrowBufPointer reusablePointer1 = new ArrowBufPointer(); + + private ArrowBufPointer reusablePointer2 = new ArrowBufPointer(); + + @Override + public int compare(int index1, int index2) { + vector1.getDataPointer(index1, reusablePointer1); + vector2.getDataPointer(index2, reusablePointer2); + return reusablePointer1.compareTo(reusablePointer2); + } + + @Override + public int compareNotNull(int index1, int index2) { + vector1.getDataPointer(index1, reusablePointer1); + vector2.getDataPointer(index2, reusablePointer2); + return reusablePointer1.compareTo(reusablePointer2); + } + + @Override + public VectorValueComparator createNew() { + return new VariableWidthComparator(); + } + } + + /** + * Default comparator for {@link BaseRepeatedValueVector}. + * It works by comparing the underlying vector in a lexicographic order. + * @param inner vector type. + */ + public static class RepeatedValueComparator + extends VectorValueComparator { + + private VectorValueComparator innerComparator; + + public RepeatedValueComparator(VectorValueComparator innerComparator) { + this.innerComparator = innerComparator; + } + + @Override + public int compareNotNull(int index1, int index2) { + int startIdx1 = vector1.getOffsetBuffer().getInt(index1 * OFFSET_WIDTH); + int startIdx2 = vector2.getOffsetBuffer().getInt(index2 * OFFSET_WIDTH); + + int endIdx1 = vector1.getOffsetBuffer().getInt((index1 + 1) * OFFSET_WIDTH); + int endIdx2 = vector2.getOffsetBuffer().getInt((index2 + 1) * OFFSET_WIDTH); + + int length1 = endIdx1 - startIdx1; + int length2 = endIdx2 - startIdx2; + + int length = length1 < length2 ? length1 : length2; + + for (int i = 0; i < length; i++) { + int result = innerComparator.compare(startIdx1 + i, startIdx2 + i); + if (result != 0) { + return result; + } + } + return length1 - length2; + } + + @Override + public VectorValueComparator createNew() { + VectorValueComparator newInnerComparator = innerComparator.createNew(); + return new RepeatedValueComparator(newInnerComparator); + } + + @Override + public void attachVectors(BaseRepeatedValueVector vector1, BaseRepeatedValueVector vector2) { + this.vector1 = vector1; + this.vector2 = vector2; + + innerComparator.attachVectors((T) vector1.getDataVector(), (T) vector2.getDataVector()); + } + } + + private DefaultVectorComparators() { + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthInPlaceVectorSorter.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthInPlaceVectorSorter.java new file mode 100644 index 000000000..aaa7ba117 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthInPlaceVectorSorter.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.vector.BaseFixedWidthVector; + +/** + * Default in-place sorter for fixed-width vectors. + * It is based on quick-sort, with average time complexity O(n*log(n)). + * @param vector type. + */ +public class FixedWidthInPlaceVectorSorter implements InPlaceVectorSorter { + + /** + * If the number of items is smaller than this threshold, we will use another algorithm to sort the data. + */ + public static final int CHANGE_ALGORITHM_THRESHOLD = 15; + + static final int STOP_CHOOSING_PIVOT_THRESHOLD = 3; + + VectorValueComparator comparator; + + /** + * The vector to sort. + */ + V vec; + + /** + * The buffer to hold the pivot. + * It always has length 1. + */ + V pivotBuffer; + + @Override + public void sortInPlace(V vec, VectorValueComparator comparator) { + try { + this.vec = vec; + this.comparator = comparator; + this.pivotBuffer = (V) vec.getField().createVector(vec.getAllocator()); + this.pivotBuffer.allocateNew(1); + this.pivotBuffer.setValueCount(1); + + comparator.attachVectors(vec, pivotBuffer); + quickSort(); + } finally { + this.pivotBuffer.close(); + } + } + + private void quickSort() { + try (OffHeapIntStack rangeStack = new OffHeapIntStack(vec.getAllocator())) { + rangeStack.push(0); + rangeStack.push(vec.getValueCount() - 1); + + while (!rangeStack.isEmpty()) { + int high = rangeStack.pop(); + int low = rangeStack.pop(); + if (low < high) { + if (high - low < CHANGE_ALGORITHM_THRESHOLD) { + // switch to insertion sort + InsertionSorter.insertionSort(vec, low, high, comparator, pivotBuffer); + continue; + } + + int mid = partition(low, high); + + // push the larger part to stack first, + // to reduce the required stack size + if (high - mid < mid - low) { + rangeStack.push(low); + rangeStack.push(mid - 1); + + rangeStack.push(mid + 1); + rangeStack.push(high); + } else { + rangeStack.push(mid + 1); + rangeStack.push(high); + + rangeStack.push(low); + rangeStack.push(mid - 1); + } + } + } + } + } + + /** + * Select the pivot as the median of 3 samples. + */ + void choosePivot(int low, int high) { + // we need at least 3 items + if (high - low + 1 < STOP_CHOOSING_PIVOT_THRESHOLD) { + pivotBuffer.copyFrom(low, 0, vec); + return; + } + + comparator.attachVector(vec); + int mid = low + (high - low) / 2; + + // find the median by at most 3 comparisons + int medianIdx; + if (comparator.compare(low, mid) < 0) { + if (comparator.compare(mid, high) < 0) { + medianIdx = mid; + } else { + if (comparator.compare(low, high) < 0) { + medianIdx = high; + } else { + medianIdx = low; + } + } + } else { + if (comparator.compare(mid, high) > 0) { + medianIdx = mid; + } else { + if (comparator.compare(low, high) < 0) { + medianIdx = low; + } else { + medianIdx = high; + } + } + } + + // move the pivot to the low position, if necessary + if (medianIdx != low) { + pivotBuffer.copyFrom(medianIdx, 0, vec); + vec.copyFrom(low, medianIdx, vec); + vec.copyFrom(0, low, pivotBuffer); + } else { + pivotBuffer.copyFrom(low, 0, vec); + } + + comparator.attachVectors(vec, pivotBuffer); + } + + private int partition(int low, int high) { + choosePivot(low, high); + + while (low < high) { + while (low < high && comparator.compare(high, 0) >= 0) { + high -= 1; + } + vec.copyFrom(high, low, vec); + + while (low < high && comparator.compare(low, 0) <= 0) { + low += 1; + } + vec.copyFrom(low, high, vec); + } + + vec.copyFrom(0, low, pivotBuffer); + return low; + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthOutOfPlaceVectorSorter.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthOutOfPlaceVectorSorter.java new file mode 100644 index 000000000..4f6c76657 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthOutOfPlaceVectorSorter.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.IntVector; + +import io.netty.util.internal.PlatformDependent; + +/** + * Default out-of-place sorter for fixed-width vectors. + * It is an out-of-place sort, with time complexity O(n*log(n)). + * @param vector type. + */ +public class FixedWidthOutOfPlaceVectorSorter implements OutOfPlaceVectorSorter { + + protected IndexSorter indexSorter = new IndexSorter<>(); + + @Override + public void sortOutOfPlace(V srcVector, V dstVector, VectorValueComparator comparator) { + comparator.attachVector(srcVector); + + int valueWidth = comparator.getValueWidth(); + + // buffers referenced in the sort + ArrowBuf srcValueBuffer = srcVector.getDataBuffer(); + ArrowBuf dstValidityBuffer = dstVector.getValidityBuffer(); + ArrowBuf dstValueBuffer = dstVector.getDataBuffer(); + + // check buffer size + Preconditions.checkArgument(dstValidityBuffer.capacity() * 8 >= srcVector.getValueCount(), + "Not enough capacity for the validity buffer of the dst vector. " + + "Expected capacity %s, actual capacity %s", + (srcVector.getValueCount() + 7) / 8, dstValidityBuffer.capacity()); + Preconditions.checkArgument( + dstValueBuffer.capacity() >= srcVector.getValueCount() * srcVector.getTypeWidth(), + "Not enough capacity for the data buffer of the dst vector. " + + "Expected capacity %s, actual capacity %s", + srcVector.getValueCount() * srcVector.getTypeWidth(), dstValueBuffer.capacity()); + + // sort value indices + try (IntVector sortedIndices = new IntVector("", srcVector.getAllocator())) { + sortedIndices.allocateNew(srcVector.getValueCount()); + sortedIndices.setValueCount(srcVector.getValueCount()); + indexSorter.sort(srcVector, sortedIndices, comparator); + + // copy sorted values to the output vector + for (int dstIndex = 0; dstIndex < sortedIndices.getValueCount(); dstIndex++) { + int srcIndex = sortedIndices.get(dstIndex); + if (srcVector.isNull(srcIndex)) { + BitVectorHelper.unsetBit(dstValidityBuffer, dstIndex); + } else { + BitVectorHelper.setBit(dstValidityBuffer, dstIndex); + PlatformDependent.copyMemory( + srcValueBuffer.memoryAddress() + srcIndex * valueWidth, + dstValueBuffer.memoryAddress() + dstIndex * valueWidth, + valueWidth); + } + } + } + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/InPlaceVectorSorter.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/InPlaceVectorSorter.java new file mode 100644 index 000000000..19817fe76 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/InPlaceVectorSorter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.vector.ValueVector; + +/** + * Basic interface for sorting a vector in-place. + * That is, the sorting is performed by modifying the input vector, + * without creating a new sorted vector. + * + * @param the vector type. + */ +public interface InPlaceVectorSorter { + + /** + * Sort a vector in-place. + * @param vec the vector to sort. + * @param comparator the criteria for sort. + */ + void sortInPlace(V vec, VectorValueComparator comparator); +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/IndexSorter.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/IndexSorter.java new file mode 100644 index 000000000..3072717f4 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/IndexSorter.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import java.util.stream.IntStream; + +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Sorter for the indices of a vector. + * @param vector type. + */ +public class IndexSorter { + + /** + * If the number of items is smaller than this threshold, we will use another algorithm to sort the data. + */ + public static final int CHANGE_ALGORITHM_THRESHOLD = 15; + + /** + * Comparator for vector indices. + */ + private VectorValueComparator comparator; + + /** + * Vector indices to sort. + */ + private IntVector indices; + + /** + * Sorts indices, by quick-sort. Suppose the vector is denoted by v. + * After calling this method, the following relations hold: + * v(indices[0]) <= v(indices[1]) <= ... + * @param vector the vector whose indices need to be sorted. + * @param indices the vector for storing the sorted indices. + * @param comparator the comparator to sort indices. + */ + public void sort(V vector, IntVector indices, VectorValueComparator comparator) { + comparator.attachVector(vector); + + this.indices = indices; + + IntStream.range(0, vector.getValueCount()).forEach(i -> indices.set(i, i)); + + this.comparator = comparator; + + quickSort(); + } + + private void quickSort() { + try (OffHeapIntStack rangeStack = new OffHeapIntStack(indices.getAllocator())) { + rangeStack.push(0); + rangeStack.push(indices.getValueCount() - 1); + + while (!rangeStack.isEmpty()) { + int high = rangeStack.pop(); + int low = rangeStack.pop(); + + if (low < high) { + if (high - low < CHANGE_ALGORITHM_THRESHOLD) { + InsertionSorter.insertionSort(indices, low, high, comparator); + continue; + } + + int mid = partition(low, high, indices, comparator); + + // push the larger part to stack first, + // to reduce the required stack size + if (high - mid < mid - low) { + rangeStack.push(low); + rangeStack.push(mid - 1); + + rangeStack.push(mid + 1); + rangeStack.push(high); + } else { + rangeStack.push(mid + 1); + rangeStack.push(high); + + rangeStack.push(low); + rangeStack.push(mid - 1); + } + } + } + } + } + + /** + * Select the pivot as the median of 3 samples. + */ + static int choosePivot( + int low, int high, IntVector indices, VectorValueComparator comparator) { + // we need at least 3 items + if (high - low + 1 < FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD) { + return indices.get(low); + } + + int mid = low + (high - low) / 2; + + // find the median by at most 3 comparisons + int medianIdx; + if (comparator.compare(indices.get(low), indices.get(mid)) < 0) { + if (comparator.compare(indices.get(mid), indices.get(high)) < 0) { + medianIdx = mid; + } else { + if (comparator.compare(indices.get(low), indices.get(high)) < 0) { + medianIdx = high; + } else { + medianIdx = low; + } + } + } else { + if (comparator.compare(indices.get(mid), indices.get(high)) > 0) { + medianIdx = mid; + } else { + if (comparator.compare(indices.get(low), indices.get(high)) < 0) { + medianIdx = low; + } else { + medianIdx = high; + } + } + } + + // move the pivot to the low position, if necessary + if (medianIdx != low) { + int tmp = indices.get(medianIdx); + indices.set(medianIdx, indices.get(low)); + indices.set(low, tmp); + return tmp; + } else { + return indices.get(low); + } + } + + /** + * Partition a range of values in a vector into two parts, with elements in one part smaller than + * elements from the other part. The partition is based on the element indices, so it does + * not modify the underlying vector. + * @param low the lower bound of the range. + * @param high the upper bound of the range. + * @param indices vector element indices. + * @param comparator criteria for comparison. + * @param the vector type. + * @return the index of the split point. + */ + public static int partition( + int low, int high, IntVector indices, VectorValueComparator comparator) { + int pivotIndex = choosePivot(low, high, indices, comparator); + + while (low < high) { + while (low < high && comparator.compare(indices.get(high), pivotIndex) >= 0) { + high -= 1; + } + indices.set(low, indices.get(high)); + + while (low < high && comparator.compare(indices.get(low), pivotIndex) <= 0) { + low += 1; + } + indices.set(high, indices.get(low)); + } + + indices.set(low, pivotIndex); + return low; + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/InsertionSorter.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/InsertionSorter.java new file mode 100644 index 000000000..dc12a5fef --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/InsertionSorter.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Insertion sorter. + */ +class InsertionSorter { + + /** + * Sorts the range of a vector by insertion sort. + * + * @param vector the vector to be sorted. + * @param startIdx the start index of the range (inclusive). + * @param endIdx the end index of the range (inclusive). + * @param buffer an extra buffer with capacity 1 to hold the current key. + * @param comparator the criteria for vector element comparison. + * @param the vector type. + */ + static void insertionSort( + V vector, int startIdx, int endIdx, VectorValueComparator comparator, V buffer) { + comparator.attachVectors(vector, buffer); + for (int i = startIdx; i <= endIdx; i++) { + buffer.copyFrom(i, 0, vector); + int j = i - 1; + while (j >= startIdx && comparator.compare(j, 0) > 0) { + vector.copyFrom(j, j + 1, vector); + j = j - 1; + } + vector.copyFrom(0, j + 1, buffer); + } + } + + /** + * Sorts the range of vector indices by insertion sort. + * + * @param indices the vector indices. + * @param startIdx the start index of the range (inclusive). + * @param endIdx the end index of the range (inclusive). + * @param comparator the criteria for vector element comparison. + * @param the vector type. + */ + static void insertionSort( + IntVector indices, int startIdx, int endIdx, VectorValueComparator comparator) { + for (int i = startIdx; i <= endIdx; i++) { + int key = indices.get(i); + int j = i - 1; + while (j >= startIdx && comparator.compare(indices.get(j), key) > 0) { + indices.set(j + 1, indices.get(j)); + j = j - 1; + } + indices.set(j + 1, key); + } + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/OffHeapIntStack.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/OffHeapIntStack.java new file mode 100644 index 000000000..df96121f1 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/OffHeapIntStack.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.IntVector; + +/** + * An off heap implementation of stack with int elements. + */ +class OffHeapIntStack implements AutoCloseable { + + private static final int INIT_SIZE = 128; + + private IntVector intVector; + + private int top = 0; + + public OffHeapIntStack(BufferAllocator allocator) { + intVector = new IntVector("int stack inner vector", allocator); + intVector.allocateNew(INIT_SIZE); + intVector.setValueCount(INIT_SIZE); + } + + public void push(int value) { + if (top == intVector.getValueCount()) { + int targetCapacity = intVector.getValueCount() * 2; + while (intVector.getValueCapacity() < targetCapacity) { + intVector.reAlloc(); + } + intVector.setValueCount(targetCapacity); + } + + intVector.set(top++, value); + } + + public int pop() { + return intVector.get(--top); + } + + public int getTop() { + return intVector.get(top - 1); + } + + public boolean isEmpty() { + return top == 0; + } + + public int getCount() { + return top; + } + + @Override + public void close() { + intVector.close(); + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/OutOfPlaceVectorSorter.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/OutOfPlaceVectorSorter.java new file mode 100644 index 000000000..41d6dadc4 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/OutOfPlaceVectorSorter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.vector.ValueVector; + +/** + * Basic interface for sorting a vector out-of-place. + * That is, the sorting is performed on a newly-created vector, + * and the original vector is not modified. + * @param the vector type. + */ +public interface OutOfPlaceVectorSorter { + + /** + * Sort a vector out-of-place. + * @param inVec the input vector. + * @param outVec the output vector, which has the same size as the input vector. + * @param comparator the criteria for sort. + */ + void sortOutOfPlace(V inVec, V outVec, VectorValueComparator comparator); +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/StableVectorComparator.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/StableVectorComparator.java new file mode 100644 index 000000000..0b0c3bd55 --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/StableVectorComparator.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.ValueVector; + +/** + * Stable sorter. It compares values like ordinary comparators. + * However, when values are equal, it breaks ties by the value indices. + * Therefore, sort algorithms using this comparator always produce + * stable sort results. + * @param type of the vector. + */ +public class StableVectorComparator extends VectorValueComparator { + + private final VectorValueComparator innerComparator; + + /** + * Constructs a stable comparator from a given comparator. + * @param innerComparator the comparator to convert to stable comparator.. + */ + public StableVectorComparator(VectorValueComparator innerComparator) { + this.innerComparator = innerComparator; + } + + @Override + public void attachVector(V vector) { + super.attachVector(vector); + innerComparator.attachVector(vector); + } + + @Override + public void attachVectors(V vector1, V vector2) { + Preconditions.checkArgument(vector1 == vector2, + "Stable comparator only supports comparing values from the same vector"); + super.attachVectors(vector1, vector2); + innerComparator.attachVectors(vector1, vector2); + } + + @Override + public int compareNotNull(int index1, int index2) { + int result = innerComparator.compare(index1, index2); + return result != 0 ? result : index1 - index2; + } + + @Override + public VectorValueComparator createNew() { + return new StableVectorComparator(innerComparator.createNew()); + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VariableWidthOutOfPlaceVectorSorter.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VariableWidthOutOfPlaceVectorSorter.java new file mode 100644 index 000000000..62003752e --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VariableWidthOutOfPlaceVectorSorter.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.IntVector; + +import io.netty.util.internal.PlatformDependent; + +/** + * Default sorter for variable-width vectors. + * It is an out-of-place sort, with time complexity O(n*log(n)). + * @param vector type. + */ +public class VariableWidthOutOfPlaceVectorSorter + implements OutOfPlaceVectorSorter { + + protected IndexSorter indexSorter = new IndexSorter<>(); + + @Override + public void sortOutOfPlace(V srcVector, V dstVector, VectorValueComparator comparator) { + comparator.attachVector(srcVector); + + // buffers referenced in the sort + ArrowBuf srcValueBuffer = srcVector.getDataBuffer(); + ArrowBuf srcOffsetBuffer = srcVector.getOffsetBuffer(); + ArrowBuf dstValidityBuffer = dstVector.getValidityBuffer(); + ArrowBuf dstValueBuffer = dstVector.getDataBuffer(); + ArrowBuf dstOffsetBuffer = dstVector.getOffsetBuffer(); + + // check buffer size + Preconditions.checkArgument(dstValidityBuffer.capacity() * 8 >= srcVector.getValueCount(), + "Not enough capacity for the validity buffer of the dst vector. " + + "Expected capacity %s, actual capacity %s", + (srcVector.getValueCount() + 7) / 8, dstValidityBuffer.capacity()); + Preconditions.checkArgument( + dstOffsetBuffer.capacity() >= (srcVector.getValueCount() + 1) * BaseVariableWidthVector.OFFSET_WIDTH, + "Not enough capacity for the offset buffer of the dst vector. " + + "Expected capacity %s, actual capacity %s", + (srcVector.getValueCount() + 1) * BaseVariableWidthVector.OFFSET_WIDTH, dstOffsetBuffer.capacity()); + long dataSize = srcVector.getOffsetBuffer().getInt( + srcVector.getValueCount() * BaseVariableWidthVector.OFFSET_WIDTH); + Preconditions.checkArgument( + dstValueBuffer.capacity() >= dataSize, "No enough capacity for the data buffer of the dst vector. " + + "Expected capacity %s, actual capacity %s", dataSize, dstValueBuffer.capacity()); + + // sort value indices + try (IntVector sortedIndices = new IntVector("", srcVector.getAllocator())) { + sortedIndices.allocateNew(srcVector.getValueCount()); + sortedIndices.setValueCount(srcVector.getValueCount()); + indexSorter.sort(srcVector, sortedIndices, comparator); + + int dstOffset = 0; + dstOffsetBuffer.setInt(0, 0); + + // copy sorted values to the output vector + for (int dstIndex = 0; dstIndex < sortedIndices.getValueCount(); dstIndex++) { + int srcIndex = sortedIndices.get(dstIndex); + if (srcVector.isNull(srcIndex)) { + BitVectorHelper.unsetBit(dstValidityBuffer, dstIndex); + } else { + BitVectorHelper.setBit(dstValidityBuffer, dstIndex); + int srcOffset = srcOffsetBuffer.getInt(srcIndex * BaseVariableWidthVector.OFFSET_WIDTH); + int valueLength = srcOffsetBuffer.getInt((srcIndex + 1) * BaseVariableWidthVector.OFFSET_WIDTH) - srcOffset; + PlatformDependent.copyMemory( + srcValueBuffer.memoryAddress() + srcOffset, + dstValueBuffer.memoryAddress() + dstOffset, + valueLength); + dstOffset += valueLength; + } + dstOffsetBuffer.setInt((dstIndex + 1) * BaseVariableWidthVector.OFFSET_WIDTH, dstOffset); + } + } + } +} diff --git a/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VectorValueComparator.java b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VectorValueComparator.java new file mode 100644 index 000000000..ed32e16ca --- /dev/null +++ b/src/arrow/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VectorValueComparator.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import org.apache.arrow.vector.ValueVector; + +/** + * Compare two values at the given indices in the vectors. + * This is used for vector sorting. + * @param type of the vector. + */ +public abstract class VectorValueComparator { + + /** + * The first vector to compare. + */ + protected V vector1; + + /** + * The second vector to compare. + */ + protected V vector2; + + /** + * Width of the vector value. For variable-length vectors, this value makes no sense. + */ + protected int valueWidth; + + /** + * Constructor for variable-width vectors. + */ + protected VectorValueComparator() { + + } + + /** + * Constructor for fixed-width vectors. + * @param valueWidth the record width (in bytes). + */ + protected VectorValueComparator(int valueWidth) { + this.valueWidth = valueWidth; + } + + public int getValueWidth() { + return valueWidth; + } + + /** + * Attach both vectors to compare to the same input vector. + * @param vector the vector to attach. + */ + public void attachVector(V vector) { + attachVectors(vector, vector); + } + + /** + * Attach vectors to compare. + * @param vector1 the first vector to compare. + * @param vector2 the second vector to compare. + */ + public void attachVectors(V vector1, V vector2) { + this.vector1 = vector1; + this.vector2 = vector2; + } + + /** + * Compare two values, given their indices. + * @param index1 index of the first value to compare. + * @param index2 index of the second value to compare. + * @return an integer greater than 0, if the first value is greater; + * an integer smaller than 0, if the first value is smaller; or 0, if both + * values are equal. + */ + public int compare(int index1, int index2) { + boolean isNull1 = vector1.isNull(index1); + boolean isNull2 = vector2.isNull(index2); + + if (isNull1 || isNull2) { + if (isNull1 && isNull2) { + return 0; + } else if (isNull1) { + // null is smaller + return -1; + } else { + return 1; + } + } + return compareNotNull(index1, index2); + } + + /** + * Compare two values, given their indices. + * This is a fast path for comparing non-null values, so the caller + * must make sure that values at both indices are not null. + * @param index1 index of the first value to compare. + * @param index2 index of the second value to compare. + * @return an integer greater than 0, if the first value is greater; + * an integer smaller than 0, if the first value is smaller; or 0, if both + * values are equal. + */ + public abstract int compareNotNull(int index1, int index2); + + /** + * Creates a comparator of the same type. + * @return the newly created comparator. + */ + public abstract VectorValueComparator createNew(); +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestDeduplicationUtils.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestDeduplicationUtils.java new file mode 100644 index 000000000..def83fba7 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestDeduplicationUtils.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.deduplicate; + +import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.DataSizeRoundingUtil; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link DeduplicationUtils}. + */ +public class TestDeduplicationUtils { + + private static final int VECTOR_LENGTH = 100; + + private static final int REPETITION_COUNT = 3; + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testDeduplicateFixedWidth() { + try (IntVector origVec = new IntVector("original vec", allocator); + IntVector dedupVec = new IntVector("deduplicated vec", allocator); + IntVector lengthVec = new IntVector("length vec", allocator); + ArrowBuf distinctBuf = allocator.buffer( + DataSizeRoundingUtil.divideBy8Ceil(VECTOR_LENGTH * REPETITION_COUNT))) { + origVec.allocateNew(VECTOR_LENGTH * REPETITION_COUNT); + origVec.setValueCount(VECTOR_LENGTH * REPETITION_COUNT); + lengthVec.allocateNew(); + + // prepare data + for (int i = 0; i < VECTOR_LENGTH; i++) { + for (int j = 0; j < REPETITION_COUNT; j++) { + origVec.set(i * REPETITION_COUNT + j, i); + } + } + + DeduplicationUtils.populateRunStartIndicators(origVec, distinctBuf); + assertEquals( VECTOR_LENGTH, + VECTOR_LENGTH * REPETITION_COUNT - + BitVectorHelper.getNullCount(distinctBuf, VECTOR_LENGTH * REPETITION_COUNT)); + + DeduplicationUtils.populateDeduplicatedValues(distinctBuf, origVec, dedupVec); + assertEquals(VECTOR_LENGTH, dedupVec.getValueCount()); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertEquals(i, dedupVec.get(i)); + } + + DeduplicationUtils.populateRunLengths(distinctBuf, lengthVec, VECTOR_LENGTH * REPETITION_COUNT); + assertEquals(VECTOR_LENGTH, lengthVec.getValueCount()); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertEquals(REPETITION_COUNT, lengthVec.get(i)); + } + } + } + + @Test + public void testDeduplicateVariableWidth() { + try (VarCharVector origVec = new VarCharVector("original vec", allocator); + VarCharVector dedupVec = new VarCharVector("deduplicated vec", allocator); + IntVector lengthVec = new IntVector("length vec", allocator); + ArrowBuf distinctBuf = allocator.buffer( + DataSizeRoundingUtil.divideBy8Ceil(VECTOR_LENGTH * REPETITION_COUNT))) { + origVec.allocateNew( + VECTOR_LENGTH * REPETITION_COUNT * 10, VECTOR_LENGTH * REPETITION_COUNT); + origVec.setValueCount(VECTOR_LENGTH * REPETITION_COUNT); + lengthVec.allocateNew(); + + // prepare data + for (int i = 0; i < VECTOR_LENGTH; i++) { + String str = String.valueOf(i * i); + for (int j = 0; j < REPETITION_COUNT; j++) { + origVec.set(i * REPETITION_COUNT + j, str.getBytes()); + } + } + + DeduplicationUtils.populateRunStartIndicators(origVec, distinctBuf); + assertEquals(VECTOR_LENGTH, + VECTOR_LENGTH * REPETITION_COUNT - + BitVectorHelper.getNullCount(distinctBuf, VECTOR_LENGTH * REPETITION_COUNT)); + + DeduplicationUtils.populateDeduplicatedValues(distinctBuf, origVec, dedupVec); + assertEquals(VECTOR_LENGTH, dedupVec.getValueCount()); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertArrayEquals(String.valueOf(i * i).getBytes(), dedupVec.get(i)); + } + + DeduplicationUtils.populateRunLengths( + distinctBuf, lengthVec, VECTOR_LENGTH * REPETITION_COUNT); + assertEquals(VECTOR_LENGTH, lengthVec.getValueCount()); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertEquals(REPETITION_COUNT, lengthVec.get(i)); + } + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestVectorRunDeduplicator.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestVectorRunDeduplicator.java new file mode 100644 index 000000000..4bfa6e255 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestVectorRunDeduplicator.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.deduplicate; + +import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link VectorRunDeduplicator}. + */ +public class TestVectorRunDeduplicator { + + private static final int VECTOR_LENGTH = 100; + + private static final int REPETITION_COUNT = 3; + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testDeduplicateFixedWidth() { + try (IntVector origVec = new IntVector("original vec", allocator); + IntVector dedupVec = new IntVector("deduplicated vec", allocator); + IntVector lengthVec = new IntVector("length vec", allocator); + VectorRunDeduplicator deduplicator = + new VectorRunDeduplicator<>(origVec, allocator)) { + origVec.allocateNew(VECTOR_LENGTH * REPETITION_COUNT); + origVec.setValueCount(VECTOR_LENGTH * REPETITION_COUNT); + lengthVec.allocateNew(); + + // prepare data + for (int i = 0; i < VECTOR_LENGTH; i++) { + for (int j = 0; j < REPETITION_COUNT; j++) { + origVec.set(i * REPETITION_COUNT + j, i); + } + } + + int distinctCount = deduplicator.getRunCount(); + assertEquals(VECTOR_LENGTH, distinctCount); + + dedupVec.allocateNew(distinctCount); + + deduplicator.populateDeduplicatedValues(dedupVec); + assertEquals(VECTOR_LENGTH, dedupVec.getValueCount()); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertEquals(i, dedupVec.get(i)); + } + + deduplicator.populateRunLengths(lengthVec); + assertEquals(VECTOR_LENGTH, lengthVec.getValueCount()); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertEquals(REPETITION_COUNT, lengthVec.get(i)); + } + } + } + + @Test + public void testDeduplicateVariableWidth() { + try (VarCharVector origVec = new VarCharVector("original vec", allocator); + VarCharVector dedupVec = new VarCharVector("deduplicated vec", allocator); + IntVector lengthVec = new IntVector("length vec", allocator); + VectorRunDeduplicator deduplicator = + new VectorRunDeduplicator<>(origVec, allocator)) { + origVec.allocateNew( + VECTOR_LENGTH * REPETITION_COUNT * 10, VECTOR_LENGTH * REPETITION_COUNT); + origVec.setValueCount(VECTOR_LENGTH * REPETITION_COUNT); + lengthVec.allocateNew(); + + // prepare data + for (int i = 0; i < VECTOR_LENGTH; i++) { + String str = String.valueOf(i * i); + for (int j = 0; j < REPETITION_COUNT; j++) { + origVec.set(i * REPETITION_COUNT + j, str.getBytes()); + } + } + + int distinctCount = deduplicator.getRunCount(); + assertEquals(VECTOR_LENGTH, distinctCount); + + dedupVec.allocateNew(distinctCount * 10, distinctCount); + + deduplicator.populateDeduplicatedValues(dedupVec); + assertEquals(VECTOR_LENGTH, dedupVec.getValueCount()); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertArrayEquals(String.valueOf(i * i).getBytes(), dedupVec.get(i)); + } + + deduplicator.populateRunLengths(lengthVec); + assertEquals(VECTOR_LENGTH, lengthVec.getValueCount()); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertEquals(REPETITION_COUNT, lengthVec.get(i)); + } + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableBasedDictionaryBuilder.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableBasedDictionaryBuilder.java new file mode 100644 index 000000000..0a3314535 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableBasedDictionaryBuilder.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link HashTableBasedDictionaryBuilder}. + */ +public class TestHashTableBasedDictionaryBuilder { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testBuildVariableWidthDictionaryWithNull() { + try (VarCharVector vec = new VarCharVector("", allocator); + VarCharVector dictionary = new VarCharVector("", allocator)) { + + vec.allocateNew(100, 10); + vec.setValueCount(10); + + dictionary.allocateNew(); + + // fill data + vec.set(0, "hello".getBytes()); + vec.set(1, "abc".getBytes()); + vec.setNull(2); + vec.set(3, "world".getBytes()); + vec.set(4, "12".getBytes()); + vec.set(5, "dictionary".getBytes()); + vec.setNull(6); + vec.set(7, "hello".getBytes()); + vec.set(8, "good".getBytes()); + vec.set(9, "abc".getBytes()); + + HashTableBasedDictionaryBuilder dictionaryBuilder = + new HashTableBasedDictionaryBuilder<>(dictionary, true); + + int result = dictionaryBuilder.addValues(vec); + + assertEquals(7, result); + assertEquals(7, dictionary.getValueCount()); + + assertEquals("hello", new String(dictionary.get(0))); + assertEquals("abc", new String(dictionary.get(1))); + assertNull(dictionary.get(2)); + assertEquals("world", new String(dictionary.get(3))); + assertEquals("12", new String(dictionary.get(4))); + assertEquals("dictionary", new String(dictionary.get(5))); + assertEquals("good", new String(dictionary.get(6))); + } + } + + @Test + public void testBuildVariableWidthDictionaryWithoutNull() { + try (VarCharVector vec = new VarCharVector("", allocator); + VarCharVector dictionary = new VarCharVector("", allocator)) { + + vec.allocateNew(100, 10); + vec.setValueCount(10); + + dictionary.allocateNew(); + + // fill data + vec.set(0, "hello".getBytes()); + vec.set(1, "abc".getBytes()); + vec.setNull(2); + vec.set(3, "world".getBytes()); + vec.set(4, "12".getBytes()); + vec.set(5, "dictionary".getBytes()); + vec.setNull(6); + vec.set(7, "hello".getBytes()); + vec.set(8, "good".getBytes()); + vec.set(9, "abc".getBytes()); + + HashTableBasedDictionaryBuilder dictionaryBuilder = + new HashTableBasedDictionaryBuilder<>(dictionary, false); + + int result = dictionaryBuilder.addValues(vec); + + assertEquals(6, result); + assertEquals(6, dictionary.getValueCount()); + + assertEquals("hello", new String(dictionary.get(0))); + assertEquals("abc", new String(dictionary.get(1))); + assertEquals("world", new String(dictionary.get(2))); + assertEquals("12", new String(dictionary.get(3))); + assertEquals("dictionary", new String(dictionary.get(4))); + assertEquals("good", new String(dictionary.get(5))); + + } + } + + @Test + public void testBuildFixedWidthDictionaryWithNull() { + try (IntVector vec = new IntVector("", allocator); + IntVector dictionary = new IntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + dictionary.allocateNew(); + + // fill data + vec.set(0, 4); + vec.set(1, 8); + vec.set(2, 32); + vec.set(3, 8); + vec.set(4, 16); + vec.set(5, 32); + vec.setNull(6); + vec.set(7, 4); + vec.set(8, 4); + vec.setNull(9); + + HashTableBasedDictionaryBuilder dictionaryBuilder = + new HashTableBasedDictionaryBuilder<>(dictionary, true); + + int result = dictionaryBuilder.addValues(vec); + + assertEquals(5, result); + assertEquals(5, dictionary.getValueCount()); + + assertEquals(4, dictionary.get(0)); + assertEquals(8, dictionary.get(1)); + assertEquals(32, dictionary.get(2)); + assertEquals(16, dictionary.get(3)); + assertTrue(dictionary.isNull(4)); + } + } + + @Test + public void testBuildFixedWidthDictionaryWithoutNull() { + try (IntVector vec = new IntVector("", allocator); + IntVector dictionary = new IntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + dictionary.allocateNew(); + + // fill data + vec.set(0, 4); + vec.set(1, 8); + vec.set(2, 32); + vec.set(3, 8); + vec.set(4, 16); + vec.set(5, 32); + vec.setNull(6); + vec.set(7, 4); + vec.set(8, 4); + vec.setNull(9); + + HashTableBasedDictionaryBuilder dictionaryBuilder = + new HashTableBasedDictionaryBuilder<>(dictionary, false); + + int result = dictionaryBuilder.addValues(vec); + + assertEquals(4, result); + assertEquals(4, dictionary.getValueCount()); + + assertEquals(4, dictionary.get(0)); + assertEquals(8, dictionary.get(1)); + assertEquals(32, dictionary.get(2)); + assertEquals(16, dictionary.get(3)); + + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableDictionaryEncoder.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableDictionaryEncoder.java new file mode 100644 index 000000000..dd22ac96f --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableDictionaryEncoder.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Random; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link HashTableDictionaryEncoder}. + */ +public class TestHashTableDictionaryEncoder { + + private final int VECTOR_LENGTH = 50; + + private final int DICTIONARY_LENGTH = 10; + + private BufferAllocator allocator; + + byte[] zero = "000".getBytes(StandardCharsets.UTF_8); + byte[] one = "111".getBytes(StandardCharsets.UTF_8); + byte[] two = "222".getBytes(StandardCharsets.UTF_8); + + byte[][] data = new byte[][]{zero, one, two}; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testEncodeAndDecode() { + Random random = new Random(); + try (VarCharVector rawVector = new VarCharVector("original vector", allocator); + IntVector encodedVector = new IntVector("encoded vector", allocator); + VarCharVector dictionary = new VarCharVector("dictionary", allocator)) { + + // set up dictionary + dictionary.allocateNew(); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + // encode "i" as i + dictionary.setSafe(i, String.valueOf(i).getBytes()); + } + dictionary.setValueCount(DICTIONARY_LENGTH); + + // set up raw vector + rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int val = (random.nextInt() & Integer.MAX_VALUE) % DICTIONARY_LENGTH; + rawVector.set(i, String.valueOf(val).getBytes()); + } + rawVector.setValueCount(VECTOR_LENGTH); + + HashTableDictionaryEncoder encoder = + new HashTableDictionaryEncoder<>(dictionary, false); + + // perform encoding + encodedVector.allocateNew(); + encoder.encode(rawVector, encodedVector); + + // verify encoding results + assertEquals(rawVector.getValueCount(), encodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + } + + // perform decoding + Dictionary dict = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decodedVector = (VarCharVector) DictionaryEncoder.decode(encodedVector, dict)) { + + // verify decoding results + assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + } + } + } + } + + @Test + public void testEncodeAndDecodeWithNull() { + Random random = new Random(); + try (VarCharVector rawVector = new VarCharVector("original vector", allocator); + IntVector encodedVector = new IntVector("encoded vector", allocator); + VarCharVector dictionary = new VarCharVector("dictionary", allocator)) { + + // set up dictionary + dictionary.allocateNew(); + dictionary.setNull(0); + for (int i = 1; i < DICTIONARY_LENGTH; i++) { + // encode "i" as i + dictionary.setSafe(i, String.valueOf(i).getBytes()); + } + dictionary.setValueCount(DICTIONARY_LENGTH); + + // set up raw vector + rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i % 10 == 0) { + rawVector.setNull(i); + } else { + int val = (random.nextInt() & Integer.MAX_VALUE) % (DICTIONARY_LENGTH - 1) + 1; + rawVector.set(i, String.valueOf(val).getBytes()); + } + } + rawVector.setValueCount(VECTOR_LENGTH); + + HashTableDictionaryEncoder encoder = + new HashTableDictionaryEncoder<>(dictionary, true); + + // perform encoding + encodedVector.allocateNew(); + encoder.encode(rawVector, encodedVector); + + // verify encoding results + assertEquals(rawVector.getValueCount(), encodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i % 10 == 0) { + assertEquals(0, encodedVector.get(i)); + } else { + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + } + } + + // perform decoding + Dictionary dict = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decodedVector = (VarCharVector) DictionaryEncoder.decode(encodedVector, dict)) { + // verify decoding results + assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i % 10 == 0) { + assertTrue(decodedVector.isNull(i)); + } else { + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + } + } + } + } + } + + @Test + public void testEncodeNullWithoutNullInDictionary() { + try (VarCharVector rawVector = new VarCharVector("original vector", allocator); + IntVector encodedVector = new IntVector("encoded vector", allocator); + VarCharVector dictionary = new VarCharVector("dictionary", allocator)) { + + // set up dictionary, with no null in it. + dictionary.allocateNew(); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + // encode "i" as i + dictionary.setSafe(i, String.valueOf(i).getBytes()); + } + dictionary.setValueCount(DICTIONARY_LENGTH); + + // the vector to encode has a null inside. + rawVector.allocateNew(1); + rawVector.setNull(0); + rawVector.setValueCount(1); + + encodedVector.allocateNew(); + + HashTableDictionaryEncoder encoder = + new HashTableDictionaryEncoder<>(dictionary, true); + + // the encoder should encode null, but no null in the dictionary, + // so an exception should be thrown. + assertThrows(IllegalArgumentException.class, () -> { + encoder.encode(rawVector, encodedVector); + }); + } + } + + @Test + public void testEncodeStrings() { + // Create a new value vector + try (final VarCharVector vector = new VarCharVector("foo", allocator); + final IntVector encoded = new IntVector("encoded", allocator); + final VarCharVector dictionaryVector = new VarCharVector("dict", allocator)) { + + vector.allocateNew(512, 5); + encoded.allocateNew(); + + // set some values + vector.setSafe(0, zero, 0, zero.length); + vector.setSafe(1, one, 0, one.length); + vector.setSafe(2, one, 0, one.length); + vector.setSafe(3, two, 0, two.length); + vector.setSafe(4, zero, 0, zero.length); + vector.setValueCount(5); + + // set some dictionary values + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, one.length); + dictionaryVector.setSafe(1, one, 0, two.length); + dictionaryVector.setSafe(2, two, 0, zero.length); + dictionaryVector.setValueCount(3); + + HashTableDictionaryEncoder encoder = + new HashTableDictionaryEncoder<>(dictionaryVector); + encoder.encode(vector, encoded); + + // verify indices + assertEquals(5, encoded.getValueCount()); + assertEquals(0, encoded.get(0)); + assertEquals(1, encoded.get(1)); + assertEquals(1, encoded.get(2)); + assertEquals(2, encoded.get(3)); + assertEquals(0, encoded.get(4)); + + // now run through the decoder and verify we get the original back + Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, dict)) { + + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < 5; i++) { + assertEquals(vector.getObject(i), decoded.getObject(i)); + } + } + } + } + + @Test + public void testEncodeLargeVector() { + // Create a new value vector + try (final VarCharVector vector = new VarCharVector("foo", allocator); + final IntVector encoded = new IntVector("encoded", allocator); + final VarCharVector dictionaryVector = new VarCharVector("dict", allocator)) { + vector.allocateNew(); + encoded.allocateNew(); + + int count = 10000; + + for (int i = 0; i < 10000; ++i) { + vector.setSafe(i, data[i % 3], 0, data[i % 3].length); + } + vector.setValueCount(count); + + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, one.length); + dictionaryVector.setSafe(1, one, 0, two.length); + dictionaryVector.setSafe(2, two, 0, zero.length); + dictionaryVector.setValueCount(3); + + HashTableDictionaryEncoder encoder = + new HashTableDictionaryEncoder<>(dictionaryVector); + encoder.encode(vector, encoded); + + assertEquals(count, encoded.getValueCount()); + for (int i = 0; i < count; ++i) { + assertEquals(i % 3, encoded.get(i)); + } + + // now run through the decoder and verify we get the original back + Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, dict)) { + assertEquals(vector.getClass(), decoded.getClass()); + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < count; ++i) { + assertEquals(vector.getObject(i), decoded.getObject(i)); + } + } + } + } + + @Test + public void testEncodeBinaryVector() { + // Create a new value vector + try (final VarBinaryVector vector = new VarBinaryVector("foo", allocator); + final VarBinaryVector dictionaryVector = new VarBinaryVector("dict", allocator); + final IntVector encoded = new IntVector("encoded", allocator)) { + vector.allocateNew(512, 5); + vector.allocateNew(); + encoded.allocateNew(); + + // set some values + vector.setSafe(0, zero, 0, zero.length); + vector.setSafe(1, one, 0, one.length); + vector.setSafe(2, one, 0, one.length); + vector.setSafe(3, two, 0, two.length); + vector.setSafe(4, zero, 0, zero.length); + vector.setValueCount(5); + + // set some dictionary values + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, one.length); + dictionaryVector.setSafe(1, one, 0, two.length); + dictionaryVector.setSafe(2, two, 0, zero.length); + dictionaryVector.setValueCount(3); + + HashTableDictionaryEncoder encoder = + new HashTableDictionaryEncoder<>(dictionaryVector); + encoder.encode(vector, encoded); + + assertEquals(5, encoded.getValueCount()); + assertEquals(0, encoded.get(0)); + assertEquals(1, encoded.get(1)); + assertEquals(1, encoded.get(2)); + assertEquals(2, encoded.get(3)); + assertEquals(0, encoded.get(4)); + + // now run through the decoder and verify we get the original back + Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + try (VarBinaryVector decoded = (VarBinaryVector) DictionaryEncoder.decode(encoded, dict)) { + + assertEquals(vector.getClass(), decoded.getClass()); + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < 5; i++) { + assertTrue(Arrays.equals(vector.getObject(i), decoded.getObject(i))); + } + } + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestLinearDictionaryEncoder.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestLinearDictionaryEncoder.java new file mode 100644 index 000000000..104d1b35b --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestLinearDictionaryEncoder.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Random; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link LinearDictionaryEncoder}. + */ +public class TestLinearDictionaryEncoder { + + private final int VECTOR_LENGTH = 50; + + private final int DICTIONARY_LENGTH = 10; + + private BufferAllocator allocator; + + byte[] zero = "000".getBytes(StandardCharsets.UTF_8); + byte[] one = "111".getBytes(StandardCharsets.UTF_8); + byte[] two = "222".getBytes(StandardCharsets.UTF_8); + + byte[][] data = new byte[][]{zero, one, two}; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testEncodeAndDecode() { + Random random = new Random(); + try (VarCharVector rawVector = new VarCharVector("original vector", allocator); + IntVector encodedVector = new IntVector("encoded vector", allocator); + VarCharVector dictionary = new VarCharVector("dictionary", allocator)) { + + // set up dictionary + dictionary.allocateNew(); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + // encode "i" as i + dictionary.setSafe(i, String.valueOf(i).getBytes()); + } + dictionary.setValueCount(DICTIONARY_LENGTH); + + // set up raw vector + rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int val = (random.nextInt() & Integer.MAX_VALUE) % DICTIONARY_LENGTH; + rawVector.set(i, String.valueOf(val).getBytes()); + } + rawVector.setValueCount(VECTOR_LENGTH); + + LinearDictionaryEncoder encoder = + new LinearDictionaryEncoder<>(dictionary, false); + + // perform encoding + encodedVector.allocateNew(); + encoder.encode(rawVector, encodedVector); + + // verify encoding results + assertEquals(rawVector.getValueCount(), encodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + } + + // perform decoding + Dictionary dict = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decodedVector = (VarCharVector) DictionaryEncoder.decode(encodedVector, dict)) { + + // verify decoding results + assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + } + } + } + } + + @Test + public void testEncodeAndDecodeWithNull() { + Random random = new Random(); + try (VarCharVector rawVector = new VarCharVector("original vector", allocator); + IntVector encodedVector = new IntVector("encoded vector", allocator); + VarCharVector dictionary = new VarCharVector("dictionary", allocator)) { + + // set up dictionary + dictionary.allocateNew(); + dictionary.setNull(0); + for (int i = 1; i < DICTIONARY_LENGTH; i++) { + // encode "i" as i + dictionary.setSafe(i, String.valueOf(i).getBytes()); + } + dictionary.setValueCount(DICTIONARY_LENGTH); + + // set up raw vector + rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i % 10 == 0) { + rawVector.setNull(i); + } else { + int val = (random.nextInt() & Integer.MAX_VALUE) % (DICTIONARY_LENGTH - 1) + 1; + rawVector.set(i, String.valueOf(val).getBytes()); + } + } + rawVector.setValueCount(VECTOR_LENGTH); + + LinearDictionaryEncoder encoder = + new LinearDictionaryEncoder<>(dictionary, true); + + // perform encoding + encodedVector.allocateNew(); + encoder.encode(rawVector, encodedVector); + + // verify encoding results + assertEquals(rawVector.getValueCount(), encodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i % 10 == 0) { + assertEquals(0, encodedVector.get(i)); + } else { + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + } + } + + // perform decoding + Dictionary dict = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decodedVector = (VarCharVector) DictionaryEncoder.decode(encodedVector, dict)) { + + // verify decoding results + assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i % 10 == 0) { + assertTrue(decodedVector.isNull(i)); + } else { + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + } + } + } + } + } + + @Test + public void testEncodeNullWithoutNullInDictionary() { + try (VarCharVector rawVector = new VarCharVector("original vector", allocator); + IntVector encodedVector = new IntVector("encoded vector", allocator); + VarCharVector dictionary = new VarCharVector("dictionary", allocator)) { + + // set up dictionary, with no null in it. + dictionary.allocateNew(); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + // encode "i" as i + dictionary.setSafe(i, String.valueOf(i).getBytes()); + } + dictionary.setValueCount(DICTIONARY_LENGTH); + + // the vector to encode has a null inside. + rawVector.allocateNew(1); + rawVector.setNull(0); + rawVector.setValueCount(1); + + encodedVector.allocateNew(); + + LinearDictionaryEncoder encoder = + new LinearDictionaryEncoder<>(dictionary, true); + + // the encoder should encode null, but no null in the dictionary, + // so an exception should be thrown. + assertThrows(IllegalArgumentException.class, () -> { + encoder.encode(rawVector, encodedVector); + }); + } + } + + @Test + public void testEncodeStrings() { + // Create a new value vector + try (final VarCharVector vector = new VarCharVector("foo", allocator); + final IntVector encoded = new IntVector("encoded", allocator); + final VarCharVector dictionaryVector = new VarCharVector("dict", allocator)) { + + vector.allocateNew(512, 5); + encoded.allocateNew(); + + // set some values + vector.setSafe(0, zero, 0, zero.length); + vector.setSafe(1, one, 0, one.length); + vector.setSafe(2, one, 0, one.length); + vector.setSafe(3, two, 0, two.length); + vector.setSafe(4, zero, 0, zero.length); + vector.setValueCount(5); + + // set some dictionary values + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, one.length); + dictionaryVector.setSafe(1, one, 0, two.length); + dictionaryVector.setSafe(2, two, 0, zero.length); + dictionaryVector.setValueCount(3); + + LinearDictionaryEncoder encoder = + new LinearDictionaryEncoder<>(dictionaryVector); + encoder.encode(vector, encoded); + + // verify indices + assertEquals(5, encoded.getValueCount()); + assertEquals(0, encoded.get(0)); + assertEquals(1, encoded.get(1)); + assertEquals(1, encoded.get(2)); + assertEquals(2, encoded.get(3)); + assertEquals(0, encoded.get(4)); + + // now run through the decoder and verify we get the original back + Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, dict)) { + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < 5; i++) { + assertEquals(vector.getObject(i), decoded.getObject(i)); + } + } + } + } + + @Test + public void testEncodeLargeVector() { + // Create a new value vector + try (final VarCharVector vector = new VarCharVector("foo", allocator); + final IntVector encoded = new IntVector("encoded", allocator); + final VarCharVector dictionaryVector = new VarCharVector("dict", allocator)) { + vector.allocateNew(); + encoded.allocateNew(); + + int count = 10000; + + for (int i = 0; i < 10000; ++i) { + vector.setSafe(i, data[i % 3], 0, data[i % 3].length); + } + vector.setValueCount(count); + + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, one.length); + dictionaryVector.setSafe(1, one, 0, two.length); + dictionaryVector.setSafe(2, two, 0, zero.length); + dictionaryVector.setValueCount(3); + + LinearDictionaryEncoder encoder = + new LinearDictionaryEncoder<>(dictionaryVector); + encoder.encode(vector, encoded); + + assertEquals(count, encoded.getValueCount()); + for (int i = 0; i < count; ++i) { + assertEquals(i % 3, encoded.get(i)); + } + + // now run through the decoder and verify we get the original back + Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, dict)) { + assertEquals(vector.getClass(), decoded.getClass()); + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < count; ++i) { + assertEquals(vector.getObject(i), decoded.getObject(i)); + } + } + } + } + + @Test + public void testEncodeBinaryVector() { + // Create a new value vector + try (final VarBinaryVector vector = new VarBinaryVector("foo", allocator); + final VarBinaryVector dictionaryVector = new VarBinaryVector("dict", allocator); + final IntVector encoded = new IntVector("encoded", allocator)) { + vector.allocateNew(512, 5); + vector.allocateNew(); + encoded.allocateNew(); + + // set some values + vector.setSafe(0, zero, 0, zero.length); + vector.setSafe(1, one, 0, one.length); + vector.setSafe(2, one, 0, one.length); + vector.setSafe(3, two, 0, two.length); + vector.setSafe(4, zero, 0, zero.length); + vector.setValueCount(5); + + // set some dictionary values + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, one.length); + dictionaryVector.setSafe(1, one, 0, two.length); + dictionaryVector.setSafe(2, two, 0, zero.length); + dictionaryVector.setValueCount(3); + + LinearDictionaryEncoder encoder = + new LinearDictionaryEncoder<>(dictionaryVector); + encoder.encode(vector, encoded); + + assertEquals(5, encoded.getValueCount()); + assertEquals(0, encoded.get(0)); + assertEquals(1, encoded.get(1)); + assertEquals(1, encoded.get(2)); + assertEquals(2, encoded.get(3)); + assertEquals(0, encoded.get(4)); + + // now run through the decoder and verify we get the original back + Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + try (VarBinaryVector decoded = (VarBinaryVector) DictionaryEncoder.decode(encoded, dict)) { + assertEquals(vector.getClass(), decoded.getClass()); + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < 5; i++) { + Assert.assertTrue(Arrays.equals(vector.getObject(i), decoded.getObject(i))); + } + } + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchDictionaryEncoder.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchDictionaryEncoder.java new file mode 100644 index 000000000..a156e987c --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchDictionaryEncoder.java @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Random; + +import org.apache.arrow.algorithm.sort.DefaultVectorComparators; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link SearchDictionaryEncoder}. + */ +public class TestSearchDictionaryEncoder { + + private final int VECTOR_LENGTH = 50; + + private final int DICTIONARY_LENGTH = 10; + + private BufferAllocator allocator; + + byte[] zero = "000".getBytes(StandardCharsets.UTF_8); + byte[] one = "111".getBytes(StandardCharsets.UTF_8); + byte[] two = "222".getBytes(StandardCharsets.UTF_8); + + byte[][] data = new byte[][]{zero, one, two}; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testEncodeAndDecode() { + Random random = new Random(); + try (VarCharVector rawVector = new VarCharVector("original vector", allocator); + IntVector encodedVector = new IntVector("encoded vector", allocator); + VarCharVector dictionary = new VarCharVector("dictionary", allocator)) { + + // set up dictionary + dictionary.allocateNew(); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + // encode "i" as i + dictionary.setSafe(i, String.valueOf(i).getBytes()); + } + dictionary.setValueCount(DICTIONARY_LENGTH); + + // set up raw vector + rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int val = (random.nextInt() & Integer.MAX_VALUE) % DICTIONARY_LENGTH; + rawVector.set(i, String.valueOf(val).getBytes()); + } + rawVector.setValueCount(VECTOR_LENGTH); + + SearchDictionaryEncoder encoder = + new SearchDictionaryEncoder<>( + dictionary, DefaultVectorComparators.createDefaultComparator(rawVector), false); + + // perform encoding + encodedVector.allocateNew(); + encoder.encode(rawVector, encodedVector); + + // verify encoding results + assertEquals(rawVector.getValueCount(), encodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + } + + // perform decoding + Dictionary dict = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decodedVector = (VarCharVector) DictionaryEncoder.decode(encodedVector, dict)) { + + // verify decoding results + assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + } + } + } + } + + @Test + public void testEncodeAndDecodeWithNull() { + Random random = new Random(); + try (VarCharVector rawVector = new VarCharVector("original vector", allocator); + IntVector encodedVector = new IntVector("encoded vector", allocator); + VarCharVector dictionary = new VarCharVector("dictionary", allocator)) { + + // set up dictionary + dictionary.allocateNew(); + dictionary.setNull(0); + for (int i = 1; i < DICTIONARY_LENGTH; i++) { + // encode "i" as i + dictionary.setSafe(i, String.valueOf(i).getBytes()); + } + dictionary.setValueCount(DICTIONARY_LENGTH); + + // set up raw vector + rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i % 10 == 0) { + rawVector.setNull(i); + } else { + int val = (random.nextInt() & Integer.MAX_VALUE) % (DICTIONARY_LENGTH - 1) + 1; + rawVector.set(i, String.valueOf(val).getBytes()); + } + } + rawVector.setValueCount(VECTOR_LENGTH); + + SearchDictionaryEncoder encoder = + new SearchDictionaryEncoder<>( + dictionary, DefaultVectorComparators.createDefaultComparator(rawVector), true); + + // perform encoding + encodedVector.allocateNew(); + encoder.encode(rawVector, encodedVector); + + // verify encoding results + assertEquals(rawVector.getValueCount(), encodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i % 10 == 0) { + assertEquals(0, encodedVector.get(i)); + } else { + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + } + } + + // perform decoding + Dictionary dict = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decodedVector = (VarCharVector) DictionaryEncoder.decode(encodedVector, dict)) { + + // verify decoding results + assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount()); + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i % 10 == 0) { + assertTrue(decodedVector.isNull(i)); + } else { + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + } + } + } + } + } + + @Test + public void testEncodeNullWithoutNullInDictionary() { + try (VarCharVector rawVector = new VarCharVector("original vector", allocator); + IntVector encodedVector = new IntVector("encoded vector", allocator); + VarCharVector dictionary = new VarCharVector("dictionary", allocator)) { + + // set up dictionary, with no null in it. + dictionary.allocateNew(); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + // encode "i" as i + dictionary.setSafe(i, String.valueOf(i).getBytes()); + } + dictionary.setValueCount(DICTIONARY_LENGTH); + + // the vector to encode has a null inside. + rawVector.allocateNew(1); + rawVector.setNull(0); + rawVector.setValueCount(1); + + encodedVector.allocateNew(); + + SearchDictionaryEncoder encoder = + new SearchDictionaryEncoder<>( + dictionary, DefaultVectorComparators.createDefaultComparator(rawVector), true); + + // the encoder should encode null, but no null in the dictionary, + // so an exception should be thrown. + assertThrows(IllegalArgumentException.class, () -> { + encoder.encode(rawVector, encodedVector); + }); + } + } + + @Test + public void testEncodeStrings() { + // Create a new value vector + try (final VarCharVector vector = new VarCharVector("foo", allocator); + final IntVector encoded = new IntVector("encoded", allocator); + final VarCharVector dictionaryVector = new VarCharVector("dict", allocator)) { + + vector.allocateNew(512, 5); + encoded.allocateNew(); + + // set some values + vector.setSafe(0, zero, 0, zero.length); + vector.setSafe(1, one, 0, one.length); + vector.setSafe(2, one, 0, one.length); + vector.setSafe(3, two, 0, two.length); + vector.setSafe(4, zero, 0, zero.length); + vector.setValueCount(5); + + // set some dictionary values + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, one.length); + dictionaryVector.setSafe(1, one, 0, two.length); + dictionaryVector.setSafe(2, two, 0, zero.length); + dictionaryVector.setValueCount(3); + + SearchDictionaryEncoder encoder = + new SearchDictionaryEncoder<>( + dictionaryVector, DefaultVectorComparators.createDefaultComparator(vector)); + encoder.encode(vector, encoded); + + // verify indices + assertEquals(5, encoded.getValueCount()); + assertEquals(0, encoded.get(0)); + assertEquals(1, encoded.get(1)); + assertEquals(1, encoded.get(2)); + assertEquals(2, encoded.get(3)); + assertEquals(0, encoded.get(4)); + + // now run through the decoder and verify we get the original back + Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, dict)) { + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < 5; i++) { + assertEquals(vector.getObject(i), decoded.getObject(i)); + } + } + } + } + + @Test + public void testEncodeLargeVector() { + // Create a new value vector + try (final VarCharVector vector = new VarCharVector("foo", allocator); + final IntVector encoded = new IntVector("encoded", allocator); + final VarCharVector dictionaryVector = new VarCharVector("dict", allocator)) { + vector.allocateNew(); + encoded.allocateNew(); + + int count = 10000; + + for (int i = 0; i < 10000; ++i) { + vector.setSafe(i, data[i % 3], 0, data[i % 3].length); + } + vector.setValueCount(count); + + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, one.length); + dictionaryVector.setSafe(1, one, 0, two.length); + dictionaryVector.setSafe(2, two, 0, zero.length); + dictionaryVector.setValueCount(3); + + SearchDictionaryEncoder encoder = + new SearchDictionaryEncoder<>( + dictionaryVector, DefaultVectorComparators.createDefaultComparator(vector)); + encoder.encode(vector, encoded); + + assertEquals(count, encoded.getValueCount()); + for (int i = 0; i < count; ++i) { + assertEquals(i % 3, encoded.get(i)); + } + + // now run through the decoder and verify we get the original back + Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + try (VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, dict)) { + assertEquals(vector.getClass(), decoded.getClass()); + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < count; ++i) { + assertEquals(vector.getObject(i), decoded.getObject(i)); + } + } + } + } + + @Test + public void testEncodeBinaryVector() { + // Create a new value vector + try (final VarBinaryVector vector = new VarBinaryVector("foo", allocator); + final VarBinaryVector dictionaryVector = new VarBinaryVector("dict", allocator); + final IntVector encoded = new IntVector("encoded", allocator)) { + vector.allocateNew(512, 5); + vector.allocateNew(); + encoded.allocateNew(); + + // set some values + vector.setSafe(0, zero, 0, zero.length); + vector.setSafe(1, one, 0, one.length); + vector.setSafe(2, one, 0, one.length); + vector.setSafe(3, two, 0, two.length); + vector.setSafe(4, zero, 0, zero.length); + vector.setValueCount(5); + + // set some dictionary values + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, one.length); + dictionaryVector.setSafe(1, one, 0, two.length); + dictionaryVector.setSafe(2, two, 0, zero.length); + dictionaryVector.setValueCount(3); + + SearchDictionaryEncoder encoder = + new SearchDictionaryEncoder<>( + dictionaryVector, DefaultVectorComparators.createDefaultComparator(vector)); + encoder.encode(vector, encoded); + + assertEquals(5, encoded.getValueCount()); + assertEquals(0, encoded.get(0)); + assertEquals(1, encoded.get(1)); + assertEquals(1, encoded.get(2)); + assertEquals(2, encoded.get(3)); + assertEquals(0, encoded.get(4)); + + // now run through the decoder and verify we get the original back + Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + try (VarBinaryVector decoded = (VarBinaryVector) DictionaryEncoder.decode(encoded, dict)) { + assertEquals(vector.getClass(), decoded.getClass()); + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < 5; i++) { + Assert.assertTrue(Arrays.equals(vector.getObject(i), decoded.getObject(i))); + } + } + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchTreeBasedDictionaryBuilder.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchTreeBasedDictionaryBuilder.java new file mode 100644 index 000000000..d8e9edce8 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchTreeBasedDictionaryBuilder.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.dictionary; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.arrow.algorithm.sort.DefaultVectorComparators; +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link SearchTreeBasedDictionaryBuilder}. + */ +public class TestSearchTreeBasedDictionaryBuilder { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testBuildVariableWidthDictionaryWithNull() { + try (VarCharVector vec = new VarCharVector("", allocator); + VarCharVector dictionary = new VarCharVector("", allocator); + VarCharVector sortedDictionary = new VarCharVector("", allocator)) { + + vec.allocateNew(100, 10); + vec.setValueCount(10); + + dictionary.allocateNew(); + sortedDictionary.allocateNew(); + + // fill data + vec.set(0, "hello".getBytes()); + vec.set(1, "abc".getBytes()); + vec.setNull(2); + vec.set(3, "world".getBytes()); + vec.set(4, "12".getBytes()); + vec.set(5, "dictionary".getBytes()); + vec.setNull(6); + vec.set(7, "hello".getBytes()); + vec.set(8, "good".getBytes()); + vec.set(9, "abc".getBytes()); + + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + SearchTreeBasedDictionaryBuilder dictionaryBuilder = + new SearchTreeBasedDictionaryBuilder<>(dictionary, comparator, true); + + int result = dictionaryBuilder.addValues(vec); + + assertEquals(7, result); + assertEquals(7, dictionary.getValueCount()); + + dictionaryBuilder.populateSortedDictionary(sortedDictionary); + + assertTrue(sortedDictionary.isNull(0)); + assertEquals("12", new String(sortedDictionary.get(1))); + assertEquals("abc", new String(sortedDictionary.get(2))); + assertEquals("dictionary", new String(sortedDictionary.get(3))); + assertEquals("good", new String(sortedDictionary.get(4))); + assertEquals("hello", new String(sortedDictionary.get(5))); + assertEquals("world", new String(sortedDictionary.get(6))); + } + } + + @Test + public void testBuildVariableWidthDictionaryWithoutNull() { + try (VarCharVector vec = new VarCharVector("", allocator); + VarCharVector dictionary = new VarCharVector("", allocator); + VarCharVector sortedDictionary = new VarCharVector("", allocator)) { + + vec.allocateNew(100, 10); + vec.setValueCount(10); + + dictionary.allocateNew(); + sortedDictionary.allocateNew(); + + // fill data + vec.set(0, "hello".getBytes()); + vec.set(1, "abc".getBytes()); + vec.setNull(2); + vec.set(3, "world".getBytes()); + vec.set(4, "12".getBytes()); + vec.set(5, "dictionary".getBytes()); + vec.setNull(6); + vec.set(7, "hello".getBytes()); + vec.set(8, "good".getBytes()); + vec.set(9, "abc".getBytes()); + + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + SearchTreeBasedDictionaryBuilder dictionaryBuilder = + new SearchTreeBasedDictionaryBuilder<>(dictionary, comparator, false); + + int result = dictionaryBuilder.addValues(vec); + + assertEquals(6, result); + assertEquals(6, dictionary.getValueCount()); + + dictionaryBuilder.populateSortedDictionary(sortedDictionary); + + assertEquals("12", new String(sortedDictionary.get(0))); + assertEquals("abc", new String(sortedDictionary.get(1))); + assertEquals("dictionary", new String(sortedDictionary.get(2))); + assertEquals("good", new String(sortedDictionary.get(3))); + assertEquals("hello", new String(sortedDictionary.get(4))); + assertEquals("world", new String(sortedDictionary.get(5))); + } + } + + @Test + public void testBuildFixedWidthDictionaryWithNull() { + try (IntVector vec = new IntVector("", allocator); + IntVector dictionary = new IntVector("", allocator); + IntVector sortedDictionary = new IntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + dictionary.allocateNew(); + sortedDictionary.allocateNew(); + + // fill data + vec.set(0, 4); + vec.set(1, 8); + vec.set(2, 32); + vec.set(3, 8); + vec.set(4, 16); + vec.set(5, 32); + vec.setNull(6); + vec.set(7, 4); + vec.set(8, 4); + vec.setNull(9); + + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + SearchTreeBasedDictionaryBuilder dictionaryBuilder = + new SearchTreeBasedDictionaryBuilder<>(dictionary, comparator, true); + + int result = dictionaryBuilder.addValues(vec); + + assertEquals(5, result); + assertEquals(5, dictionary.getValueCount()); + + dictionaryBuilder.populateSortedDictionary(sortedDictionary); + + assertTrue(sortedDictionary.isNull(0)); + assertEquals(4, sortedDictionary.get(1)); + assertEquals(8, sortedDictionary.get(2)); + assertEquals(16, sortedDictionary.get(3)); + assertEquals(32, sortedDictionary.get(4)); + } + } + + @Test + public void testBuildFixedWidthDictionaryWithoutNull() { + try (IntVector vec = new IntVector("", allocator); + IntVector dictionary = new IntVector("", allocator); + IntVector sortedDictionary = new IntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + dictionary.allocateNew(); + sortedDictionary.allocateNew(); + + // fill data + vec.set(0, 4); + vec.set(1, 8); + vec.set(2, 32); + vec.set(3, 8); + vec.set(4, 16); + vec.set(5, 32); + vec.setNull(6); + vec.set(7, 4); + vec.set(8, 4); + vec.setNull(9); + + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + SearchTreeBasedDictionaryBuilder dictionaryBuilder = + new SearchTreeBasedDictionaryBuilder<>(dictionary, comparator, false); + + int result = dictionaryBuilder.addValues(vec); + + assertEquals(4, result); + assertEquals(4, dictionary.getValueCount()); + + dictionaryBuilder.populateSortedDictionary(sortedDictionary); + + assertEquals(4, sortedDictionary.get(0)); + assertEquals(8, sortedDictionary.get(1)); + assertEquals(16, sortedDictionary.get(2)); + assertEquals(32, sortedDictionary.get(3)); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/misc/TestPartialSumUtils.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/misc/TestPartialSumUtils.java new file mode 100644 index 000000000..4e2d5900f --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/misc/TestPartialSumUtils.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.misc; + +import static org.junit.Assert.assertEquals; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link PartialSumUtils}. + */ +public class TestPartialSumUtils { + + private static final int PARTIAL_SUM_VECTOR_LENGTH = 101; + + private static final int DELTA_VECTOR_LENGTH = 100; + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testToPartialSumVector() { + try (IntVector delta = new IntVector("delta", allocator); + IntVector partialSum = new IntVector("partial sum", allocator)) { + delta.allocateNew(DELTA_VECTOR_LENGTH); + delta.setValueCount(DELTA_VECTOR_LENGTH); + + partialSum.allocateNew(PARTIAL_SUM_VECTOR_LENGTH); + + // populate delta vector + for (int i = 0; i < delta.getValueCount(); i++) { + delta.set(i, 3); + } + + final long sumBase = 10; + PartialSumUtils.toPartialSumVector(delta, partialSum, sumBase); + + // verify results + assertEquals(PARTIAL_SUM_VECTOR_LENGTH, partialSum.getValueCount()); + for (int i = 0; i < partialSum.getValueCount(); i++) { + assertEquals(i * 3 + sumBase, partialSum.get(i)); + } + } + } + + @Test + public void testToDeltaVector() { + try (IntVector partialSum = new IntVector("partial sum", allocator); + IntVector delta = new IntVector("delta", allocator)) { + partialSum.allocateNew(PARTIAL_SUM_VECTOR_LENGTH); + partialSum.setValueCount(PARTIAL_SUM_VECTOR_LENGTH); + + delta.allocateNew(DELTA_VECTOR_LENGTH); + + // populate delta vector + final int sumBase = 10; + for (int i = 0; i < partialSum.getValueCount(); i++) { + partialSum.set(i, sumBase + 3 * i); + } + + PartialSumUtils.toDeltaVector(partialSum, delta); + + // verify results + assertEquals(DELTA_VECTOR_LENGTH, delta.getValueCount()); + for (int i = 0; i < delta.getValueCount(); i++) { + assertEquals(3, delta.get(i)); + } + } + } + + @Test + public void testFindPositionInPartialSumVector() { + try (IntVector partialSum = new IntVector("partial sum", allocator)) { + partialSum.allocateNew(PARTIAL_SUM_VECTOR_LENGTH); + partialSum.setValueCount(PARTIAL_SUM_VECTOR_LENGTH); + + // populate delta vector + final int sumBase = 10; + for (int i = 0; i < partialSum.getValueCount(); i++) { + partialSum.set(i, sumBase + 3 * i); + } + + // search and verify results + for (int i = 0; i < PARTIAL_SUM_VECTOR_LENGTH - 1; i++) { + assertEquals(i, PartialSumUtils.findPositionInPartialSumVector(partialSum, sumBase + 3 * i + 1)); + } + } + } + + @Test + public void testFindPositionInPartialSumVectorNegative() { + try (IntVector partialSum = new IntVector("partial sum", allocator)) { + partialSum.allocateNew(PARTIAL_SUM_VECTOR_LENGTH); + partialSum.setValueCount(PARTIAL_SUM_VECTOR_LENGTH); + + // populate delta vector + final int sumBase = 10; + for (int i = 0; i < partialSum.getValueCount(); i++) { + partialSum.set(i, sumBase + 3 * i); + } + + // search and verify results + assertEquals(0, PartialSumUtils.findPositionInPartialSumVector(partialSum, sumBase)); + assertEquals(-1, PartialSumUtils.findPositionInPartialSumVector(partialSum, sumBase - 1)); + assertEquals(-1, PartialSumUtils.findPositionInPartialSumVector(partialSum, + sumBase + 3 * (PARTIAL_SUM_VECTOR_LENGTH - 1))); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java new file mode 100644 index 000000000..f372a809b --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.rank; + +import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.apache.arrow.algorithm.sort.DefaultVectorComparators; +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link org.apache.arrow.algorithm.rank.VectorRank}. + */ +public class TestVectorRank { + + private BufferAllocator allocator; + + private static final int VECTOR_LENGTH = 10; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testFixedWidthRank() { + VectorRank rank = new VectorRank<>(allocator); + try (IntVector vector = new IntVector("int vec", allocator)) { + vector.allocateNew(VECTOR_LENGTH); + vector.setValueCount(VECTOR_LENGTH); + + vector.set(0, 1); + vector.set(1, 5); + vector.set(2, 3); + vector.set(3, 7); + vector.set(4, 9); + vector.set(5, 8); + vector.set(6, 2); + vector.set(7, 0); + vector.set(8, 4); + vector.set(9, 6); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vector); + assertEquals(7, rank.indexAtRank(vector, comparator, 0)); + assertEquals(0, rank.indexAtRank(vector, comparator, 1)); + assertEquals(6, rank.indexAtRank(vector, comparator, 2)); + assertEquals(2, rank.indexAtRank(vector, comparator, 3)); + assertEquals(8, rank.indexAtRank(vector, comparator, 4)); + assertEquals(1, rank.indexAtRank(vector, comparator, 5)); + assertEquals(9, rank.indexAtRank(vector, comparator, 6)); + assertEquals(3, rank.indexAtRank(vector, comparator, 7)); + assertEquals(5, rank.indexAtRank(vector, comparator, 8)); + assertEquals(4, rank.indexAtRank(vector, comparator, 9)); + } + } + + @Test + public void testVariableWidthRank() { + VectorRank rank = new VectorRank<>(allocator); + try (VarCharVector vector = new VarCharVector("varchar vec", allocator)) { + vector.allocateNew(VECTOR_LENGTH * 5, VECTOR_LENGTH); + vector.setValueCount(VECTOR_LENGTH); + + vector.set(0, String.valueOf(1).getBytes()); + vector.set(1, String.valueOf(5).getBytes()); + vector.set(2, String.valueOf(3).getBytes()); + vector.set(3, String.valueOf(7).getBytes()); + vector.set(4, String.valueOf(9).getBytes()); + vector.set(5, String.valueOf(8).getBytes()); + vector.set(6, String.valueOf(2).getBytes()); + vector.set(7, String.valueOf(0).getBytes()); + vector.set(8, String.valueOf(4).getBytes()); + vector.set(9, String.valueOf(6).getBytes()); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vector); + + assertEquals(7, rank.indexAtRank(vector, comparator, 0)); + assertEquals(0, rank.indexAtRank(vector, comparator, 1)); + assertEquals(6, rank.indexAtRank(vector, comparator, 2)); + assertEquals(2, rank.indexAtRank(vector, comparator, 3)); + assertEquals(8, rank.indexAtRank(vector, comparator, 4)); + assertEquals(1, rank.indexAtRank(vector, comparator, 5)); + assertEquals(9, rank.indexAtRank(vector, comparator, 6)); + assertEquals(3, rank.indexAtRank(vector, comparator, 7)); + assertEquals(5, rank.indexAtRank(vector, comparator, 8)); + assertEquals(4, rank.indexAtRank(vector, comparator, 9)); + } + } + + @Test + public void testRankNegative() { + VectorRank rank = new VectorRank<>(allocator); + try (IntVector vector = new IntVector("int vec", allocator)) { + vector.allocateNew(VECTOR_LENGTH); + vector.setValueCount(VECTOR_LENGTH); + + vector.set(0, 1); + vector.set(1, 5); + vector.set(2, 3); + vector.set(3, 7); + vector.set(4, 9); + vector.set(5, 8); + vector.set(6, 2); + vector.set(7, 0); + vector.set(8, 4); + vector.set(9, 6); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vector); + + assertThrows(IllegalArgumentException.class, () -> { + rank.indexAtRank(vector, comparator, VECTOR_LENGTH + 1); + }); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java new file mode 100644 index 000000000..767935aaa --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.search; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.apache.arrow.algorithm.sort.DefaultVectorComparators; +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * Test cases for {@link ParallelSearcher}. + */ +@RunWith(Parameterized.class) +public class TestParallelSearcher { + + private enum ComparatorType { + EqualityComparator, + OrderingComparator; + } + + private static final int VECTOR_LENGTH = 10000; + + private final int threadCount; + + private BufferAllocator allocator; + + private ExecutorService threadPool; + + private final ComparatorType comparatorType; + + public TestParallelSearcher(ComparatorType comparatorType, int threadCount) { + this.comparatorType = comparatorType; + this.threadCount = threadCount; + } + + @Parameterized.Parameters(name = "comparator type = {0}, thread count = {1}") + public static Collection getComparatorName() { + List params = new ArrayList<>(); + int[] threadCounts = {1, 2, 5, 10, 20, 50}; + for (ComparatorType type : ComparatorType.values()) { + for (int count : threadCounts) { + params.add(new Object[] {type, count}); + } + } + return params; + } + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + threadPool = Executors.newFixedThreadPool(threadCount); + } + + @After + public void shutdown() { + allocator.close(); + threadPool.shutdown(); + } + + @Test + public void testParallelIntSearch() throws ExecutionException, InterruptedException { + try (IntVector targetVector = new IntVector("targetVector", allocator); + IntVector keyVector = new IntVector("keyVector", allocator)) { + targetVector.allocateNew(VECTOR_LENGTH); + keyVector.allocateNew(VECTOR_LENGTH); + + // if we are comparing elements using equality semantics, we do not need a comparator here. + VectorValueComparator comparator = comparatorType == ComparatorType.EqualityComparator ? null + : DefaultVectorComparators.createDefaultComparator(targetVector); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + targetVector.set(i, i); + keyVector.set(i, i * 2); + } + targetVector.setValueCount(VECTOR_LENGTH); + keyVector.setValueCount(VECTOR_LENGTH); + + ParallelSearcher searcher = new ParallelSearcher<>(targetVector, threadPool, threadCount); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int pos = comparator == null ? searcher.search(keyVector, i) : searcher.search(keyVector, i, comparator); + if (i * 2 < VECTOR_LENGTH) { + assertEquals(i * 2, pos); + } else { + assertEquals(-1, pos); + } + } + } + } + + @Test + public void testParallelStringSearch() throws ExecutionException, InterruptedException { + try (VarCharVector targetVector = new VarCharVector("targetVector", allocator); + VarCharVector keyVector = new VarCharVector("keyVector", allocator)) { + targetVector.allocateNew(VECTOR_LENGTH); + keyVector.allocateNew(VECTOR_LENGTH); + + // if we are comparing elements using equality semantics, we do not need a comparator here. + VectorValueComparator comparator = comparatorType == ComparatorType.EqualityComparator ? null + : DefaultVectorComparators.createDefaultComparator(targetVector); + + for (int i = 0; i < VECTOR_LENGTH; i++) { + targetVector.setSafe(i, String.valueOf(i).getBytes()); + keyVector.setSafe(i, String.valueOf(i * 2).getBytes()); + } + targetVector.setValueCount(VECTOR_LENGTH); + keyVector.setValueCount(VECTOR_LENGTH); + + ParallelSearcher searcher = new ParallelSearcher<>(targetVector, threadPool, threadCount); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int pos = comparator == null ? searcher.search(keyVector, i) : searcher.search(keyVector, i, comparator); + if (i * 2 < VECTOR_LENGTH) { + assertEquals(i * 2, pos); + } else { + assertEquals(-1, pos); + } + } + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorRangeSearcher.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorRangeSearcher.java new file mode 100644 index 000000000..d7659dc4c --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorRangeSearcher.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.search; + +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import java.util.Collection; + +import org.apache.arrow.algorithm.sort.DefaultVectorComparators; +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * Test cases for {@link VectorRangeSearcher}. + */ +@RunWith(Parameterized.class) +public class TestVectorRangeSearcher { + + private BufferAllocator allocator; + + private int repeat; + + public TestVectorRangeSearcher(int repeat) { + this.repeat = repeat; + } + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testGetLowerBounds() { + final int maxValue = 100; + try (IntVector intVector = new IntVector("int vec", allocator)) { + // allocate vector + intVector.allocateNew(maxValue * repeat); + intVector.setValueCount(maxValue * repeat); + + // prepare data in sorted order + // each value is repeated some times + for (int i = 0; i < maxValue; i++) { + for (int j = 0; j < repeat; j++) { + if (i == 0) { + intVector.setNull(i * repeat + j); + } else { + intVector.set(i * repeat + j, i); + } + } + } + + // do search + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(intVector); + for (int i = 0; i < maxValue; i++) { + int result = VectorRangeSearcher.getFirstMatch(intVector, comparator, intVector, i * repeat); + assertEquals(i * repeat, result); + } + } + } + + @Test + public void testGetLowerBoundsNegative() { + final int maxValue = 100; + try (IntVector intVector = new IntVector("int vec", allocator); + IntVector negVector = new IntVector("neg vec", allocator)) { + // allocate vector + intVector.allocateNew(maxValue * repeat); + intVector.setValueCount(maxValue * repeat); + + negVector.allocateNew(maxValue); + negVector.setValueCount(maxValue); + + // prepare data in sorted order + // each value is repeated some times + for (int i = 0; i < maxValue; i++) { + for (int j = 0; j < repeat; j++) { + if (i == 0) { + intVector.setNull(i * repeat + j); + } else { + intVector.set(i * repeat + j, i); + } + } + negVector.set(i, maxValue + i); + } + + // do search + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(intVector); + for (int i = 0; i < maxValue; i++) { + int result = VectorRangeSearcher.getFirstMatch(intVector, comparator, negVector, i); + assertEquals(-1, result); + } + } + } + + @Test + public void testGetUpperBounds() { + final int maxValue = 100; + try (IntVector intVector = new IntVector("int vec", allocator)) { + // allocate vector + intVector.allocateNew(maxValue * repeat); + intVector.setValueCount(maxValue * repeat); + + // prepare data in sorted order + // each value is repeated some times + for (int i = 0; i < maxValue; i++) { + for (int j = 0; j < repeat; j++) { + if (i == 0) { + intVector.setNull(i * repeat + j); + } else { + intVector.set(i * repeat + j, i); + } + } + } + + // do search + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(intVector); + for (int i = 0; i < maxValue; i++) { + int result = VectorRangeSearcher.getLastMatch(intVector, comparator, intVector, i * repeat); + assertEquals((i + 1) * repeat - 1, result); + } + } + } + + @Test + public void testGetUpperBoundsNegative() { + final int maxValue = 100; + try (IntVector intVector = new IntVector("int vec", allocator); + IntVector negVector = new IntVector("neg vec", allocator)) { + // allocate vector + intVector.allocateNew(maxValue * repeat); + intVector.setValueCount(maxValue * repeat); + + negVector.allocateNew(maxValue); + negVector.setValueCount(maxValue); + + // prepare data in sorted order + // each value is repeated some times + for (int i = 0; i < maxValue; i++) { + for (int j = 0; j < repeat; j++) { + if (i == 0) { + intVector.setNull(i * repeat + j); + } else { + intVector.set(i * repeat + j, i); + } + } + negVector.set(i, maxValue + i); + } + + // do search + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(intVector); + for (int i = 0; i < maxValue; i++) { + int result = VectorRangeSearcher.getLastMatch(intVector, comparator, negVector, i); + assertEquals(-1, result); + } + } + } + + @Parameterized.Parameters(name = "repeat = {0}") + public static Collection getRepeat() { + return Arrays.asList( + new Object[]{1}, + new Object[]{2}, + new Object[]{5}, + new Object[]{10} + ); + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java new file mode 100644 index 000000000..2847ddbb8 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.search; + +import static org.apache.arrow.vector.complex.BaseRepeatedValueVector.OFFSET_WIDTH; +import static org.junit.Assert.assertEquals; + +import org.apache.arrow.algorithm.sort.DefaultVectorComparators; +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link org.apache.arrow.algorithm.search.VectorSearcher}. + */ +public class TestVectorSearcher { + + private final int VECTOR_LENGTH = 100; + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testBinarySearchInt() { + try (IntVector rawVector = new IntVector("", allocator); + IntVector negVector = new IntVector("", allocator)) { + rawVector.allocateNew(VECTOR_LENGTH); + rawVector.setValueCount(VECTOR_LENGTH); + negVector.allocateNew(1); + negVector.setValueCount(1); + + // prepare data in sorted order + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i == 0) { + rawVector.setNull(i); + } else { + rawVector.set(i, i); + } + } + negVector.set(0, -333); + + // do search + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(rawVector); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int result = VectorSearcher.binarySearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + assertEquals(-1, VectorSearcher.binarySearch(rawVector, comparator, negVector, 0)); + } + } + + @Test + public void testLinearSearchInt() { + try (IntVector rawVector = new IntVector("", allocator); + IntVector negVector = new IntVector("", allocator)) { + rawVector.allocateNew(VECTOR_LENGTH); + rawVector.setValueCount(VECTOR_LENGTH); + negVector.allocateNew(1); + negVector.setValueCount(1); + + // prepare data in sorted order + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i == 0) { + rawVector.setNull(i); + } else { + rawVector.set(i, i); + } + } + negVector.set(0, -333); + + // do search + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(rawVector); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int result = VectorSearcher.linearSearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + assertEquals(-1, VectorSearcher.linearSearch(rawVector, comparator, negVector, 0)); + } + } + + @Test + public void testBinarySearchVarChar() { + try (VarCharVector rawVector = new VarCharVector("", allocator); + VarCharVector negVector = new VarCharVector("", allocator)) { + rawVector.allocateNew(VECTOR_LENGTH * 16, VECTOR_LENGTH); + rawVector.setValueCount(VECTOR_LENGTH); + negVector.allocateNew(VECTOR_LENGTH, 1); + negVector.setValueCount(1); + + byte[] content = new byte[2]; + + // prepare data in sorted order + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i == 0) { + rawVector.setNull(i); + } else { + int q = i / 10; + int r = i % 10; + + content[0] = (byte) ('a' + q); + content[1] = (byte) r; + rawVector.set(i, content); + } + } + negVector.set(0, "abcd".getBytes()); + + // do search + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(rawVector); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int result = VectorSearcher.binarySearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + assertEquals(-1, VectorSearcher.binarySearch(rawVector, comparator, negVector, 0)); + } + } + + @Test + public void testLinearSearchVarChar() { + try (VarCharVector rawVector = new VarCharVector("", allocator); + VarCharVector negVector = new VarCharVector("", allocator)) { + rawVector.allocateNew(VECTOR_LENGTH * 16, VECTOR_LENGTH); + rawVector.setValueCount(VECTOR_LENGTH); + negVector.allocateNew(VECTOR_LENGTH, 1); + negVector.setValueCount(1); + + byte[] content = new byte[2]; + + // prepare data in sorted order + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i == 0) { + rawVector.setNull(i); + } else { + int q = i / 10; + int r = i % 10; + + content[0] = (byte) ('a' + q); + content[1] = (byte) r; + rawVector.set(i, content); + } + } + negVector.set(0, "abcd".getBytes()); + + // do search + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(rawVector); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int result = VectorSearcher.linearSearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + assertEquals(-1, VectorSearcher.linearSearch(rawVector, comparator, negVector, 0)); + } + } + + private ListVector createListVector() { + final int innerCount = 100; + final int outerCount = 10; + final int listLength = innerCount / outerCount; + + ListVector listVector = ListVector.empty("list vector", allocator); + + Types.MinorType type = Types.MinorType.INT; + listVector.addOrGetVector(FieldType.nullable(type.getType())); + + listVector.allocateNew(); + + IntVector dataVector = (IntVector) listVector.getDataVector(); + + for (int i = 0; i < innerCount; i++) { + dataVector.set(i, i); + } + dataVector.setValueCount(innerCount); + + for (int i = 0; i < outerCount; i++) { + BitVectorHelper.setBit(listVector.getValidityBuffer(), i); + listVector.getOffsetBuffer().setInt(i * OFFSET_WIDTH, i * listLength); + listVector.getOffsetBuffer().setInt((i + 1) * OFFSET_WIDTH, (i + 1) * listLength); + } + listVector.setLastSet(outerCount - 1); + listVector.setValueCount(outerCount); + + return listVector; + } + + private ListVector createNegativeListVector() { + final int innerCount = 100; + final int outerCount = 10; + final int listLength = innerCount / outerCount; + + ListVector listVector = ListVector.empty("list vector", allocator); + + Types.MinorType type = Types.MinorType.INT; + listVector.addOrGetVector(FieldType.nullable(type.getType())); + + listVector.allocateNew(); + + IntVector dataVector = (IntVector) listVector.getDataVector(); + + for (int i = 0; i < innerCount; i++) { + dataVector.set(i, i + 1000); + } + dataVector.setValueCount(innerCount); + + for (int i = 0; i < outerCount; i++) { + BitVectorHelper.setBit(listVector.getValidityBuffer(), i); + listVector.getOffsetBuffer().setInt(i * OFFSET_WIDTH, i * listLength); + listVector.getOffsetBuffer().setInt((i + 1) * OFFSET_WIDTH, (i + 1) * listLength); + } + listVector.setValueCount(outerCount); + + return listVector; + } + + @Test + public void testBinarySearchList() { + try (ListVector rawVector = createListVector(); + ListVector negVector = createNegativeListVector()) { + + // do search + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(rawVector); + for (int i = 0; i < rawVector.getValueCount(); i++) { + int result = VectorSearcher.binarySearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + for (int i = 0; i < rawVector.getValueCount(); i++) { + int result = VectorSearcher.binarySearch(rawVector, comparator, negVector, i); + assertEquals(-1, result); + } + } + } + + @Test + public void testLinearSearchList() { + try (ListVector rawVector = createListVector(); + ListVector negVector = createNegativeListVector()) { + + // do search + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(rawVector); + for (int i = 0; i < rawVector.getValueCount(); i++) { + int result = VectorSearcher.linearSearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + for (int i = 0; i < rawVector.getValueCount(); i++) { + int result = VectorSearcher.linearSearch(rawVector, comparator, negVector, i); + assertEquals(-1, result); + } + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestCompositeVectorComparator.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestCompositeVectorComparator.java new file mode 100644 index 000000000..cac9933cc --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestCompositeVectorComparator.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link CompositeVectorComparator}. + */ +public class TestCompositeVectorComparator { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testCompareVectorSchemaRoot() { + final int vectorLength = 10; + IntVector intVec1 = new IntVector("int1", allocator); + VarCharVector strVec1 = new VarCharVector("str1", allocator); + + IntVector intVec2 = new IntVector("int2", allocator); + VarCharVector strVec2 = new VarCharVector("str2", allocator); + + try (VectorSchemaRoot batch1 = new VectorSchemaRoot(Arrays.asList(intVec1, strVec1)); + VectorSchemaRoot batch2 = new VectorSchemaRoot(Arrays.asList(intVec2, strVec2))) { + + intVec1.allocateNew(vectorLength); + strVec1.allocateNew(vectorLength * 10, vectorLength); + intVec2.allocateNew(vectorLength); + strVec2.allocateNew(vectorLength * 10, vectorLength); + + for (int i = 0; i < vectorLength; i++) { + intVec1.set(i, i); + strVec1.set(i, new String("a" + i).getBytes()); + intVec2.set(i, i); + strVec2.set(i, new String("a5").getBytes()); + } + + VectorValueComparator innerComparator1 = + DefaultVectorComparators.createDefaultComparator(intVec1); + innerComparator1.attachVectors(intVec1, intVec2); + VectorValueComparator innerComparator2 = + DefaultVectorComparators.createDefaultComparator(strVec1); + innerComparator2.attachVectors(strVec1, strVec2); + + VectorValueComparator comparator = new CompositeVectorComparator( + new VectorValueComparator[]{innerComparator1, innerComparator2} + ); + + // verify results + + // both elements are equal, the result is equal + assertTrue(comparator.compare(5, 5) == 0); + + // the first element being equal, the second is smaller, and the result is smaller + assertTrue(comparator.compare(1, 1) < 0); + assertTrue(comparator.compare(2, 2) < 0); + assertTrue(comparator.compare(3, 3) < 0); + + // the first element being equal, the second is larger, and the result is larger + assertTrue(comparator.compare(7, 7) > 0); + assertTrue(comparator.compare(8, 8) > 0); + assertTrue(comparator.compare(9, 9) > 0); + + // the first element is smaller, the result is always smaller + assertTrue(comparator.compare(1, 2) < 0); + assertTrue(comparator.compare(3, 7) < 0); + assertTrue(comparator.compare(4, 9) < 0); + + // the first element is larger, the result is always larger + assertTrue(comparator.compare(2, 0) > 0); + assertTrue(comparator.compare(8, 7) > 0); + assertTrue(comparator.compare(4, 1) > 0); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java new file mode 100644 index 000000000..2fbf598bf --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.apache.arrow.vector.complex.BaseRepeatedValueVector.OFFSET_WIDTH; +import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.testing.ValueVectorDataPopulator; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link DefaultVectorComparators}. + */ +public class TestDefaultVectorComparator { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + private ListVector createListVector(int count) { + ListVector listVector = ListVector.empty("list vector", allocator); + Types.MinorType type = Types.MinorType.INT; + listVector.addOrGetVector(FieldType.nullable(type.getType())); + listVector.allocateNew(); + + IntVector dataVector = (IntVector) listVector.getDataVector(); + + for (int i = 0; i < count; i++) { + dataVector.set(i, i); + } + dataVector.setValueCount(count); + + listVector.setNotNull(0); + + listVector.getOffsetBuffer().setInt(0, 0); + listVector.getOffsetBuffer().setInt(OFFSET_WIDTH, count); + + listVector.setLastSet(0); + listVector.setValueCount(1); + + return listVector; + } + + @Test + public void testCompareLists() { + try (ListVector listVector1 = createListVector(10); + ListVector listVector2 = createListVector(11)) { + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(listVector1); + comparator.attachVectors(listVector1, listVector2); + + // prefix is smaller + assertTrue(comparator.compare(0, 0) < 0); + } + + try (ListVector listVector1 = createListVector(11); + ListVector listVector2 = createListVector(11)) { + ((IntVector) listVector2.getDataVector()).set(10, 110); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(listVector1); + comparator.attachVectors(listVector1, listVector2); + + // breaking tie by the last element + assertTrue(comparator.compare(0, 0) < 0); + } + + try (ListVector listVector1 = createListVector(10); + ListVector listVector2 = createListVector(10)) { + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(listVector1); + comparator.attachVectors(listVector1, listVector2); + + // list vector elements equal + assertTrue(comparator.compare(0, 0) == 0); + } + } + + @Test + public void testCopiedComparatorForLists() { + for (int i = 1; i < 10; i++) { + for (int j = 1; j < 10; j++) { + try (ListVector listVector1 = createListVector(10); + ListVector listVector2 = createListVector(11)) { + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(listVector1); + comparator.attachVectors(listVector1, listVector2); + + VectorValueComparator copyComparator = comparator.createNew(); + copyComparator.attachVectors(listVector1, listVector2); + + assertEquals(comparator.compare(0, 0), copyComparator.compare(0, 0)); + } + } + } + } + + @Test + public void testCompareUInt1() { + try (UInt1Vector vec = new UInt1Vector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + vec.setNull(0); + vec.set(1, -2); + vec.set(2, -1); + vec.set(3, 0); + vec.set(4, 1); + vec.set(5, 2); + vec.set(6, -2); + vec.setNull(7); + vec.set(8, Byte.MAX_VALUE); + vec.set(9, Byte.MIN_VALUE); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + + assertTrue(comparator.compare(0, 1) < 0); + assertTrue(comparator.compare(1, 2) < 0); + assertTrue(comparator.compare(1, 3) > 0); + assertTrue(comparator.compare(2, 5) > 0); + assertTrue(comparator.compare(4, 5) < 0); + assertTrue(comparator.compare(1, 6) == 0); + assertTrue(comparator.compare(0, 7) == 0); + assertTrue(comparator.compare(8, 9) < 0); + assertTrue(comparator.compare(4, 8) < 0); + assertTrue(comparator.compare(5, 9) < 0); + assertTrue(comparator.compare(2, 9) > 0); + } + } + + @Test + public void testCompareUInt2() { + try (UInt2Vector vec = new UInt2Vector("", allocator)) { + vec.allocateNew(10); + + ValueVectorDataPopulator.setVector( + vec, null, (char) -2, (char) -1, (char) 0, (char) 1, (char) 2, (char) -2, null, + '\u7FFF', // value for the max 16-byte signed integer + '\u8000' // value for the min 16-byte signed integer + ); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + + assertTrue(comparator.compare(0, 1) < 0); + assertTrue(comparator.compare(1, 2) < 0); + assertTrue(comparator.compare(1, 3) > 0); + assertTrue(comparator.compare(2, 5) > 0); + assertTrue(comparator.compare(4, 5) < 0); + assertTrue(comparator.compare(1, 6) == 0); + assertTrue(comparator.compare(0, 7) == 0); + assertTrue(comparator.compare(8, 9) < 0); + assertTrue(comparator.compare(4, 8) < 0); + assertTrue(comparator.compare(5, 9) < 0); + assertTrue(comparator.compare(2, 9) > 0); + } + } + + @Test + public void testCompareUInt4() { + try (UInt4Vector vec = new UInt4Vector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + vec.setNull(0); + vec.set(1, -2); + vec.set(2, -1); + vec.set(3, 0); + vec.set(4, 1); + vec.set(5, 2); + vec.set(6, -2); + vec.setNull(7); + vec.set(8, Integer.MAX_VALUE); + vec.set(9, Integer.MIN_VALUE); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + + assertTrue(comparator.compare(0, 1) < 0); + assertTrue(comparator.compare(1, 2) < 0); + assertTrue(comparator.compare(1, 3) > 0); + assertTrue(comparator.compare(2, 5) > 0); + assertTrue(comparator.compare(4, 5) < 0); + assertTrue(comparator.compare(1, 6) == 0); + assertTrue(comparator.compare(0, 7) == 0); + assertTrue(comparator.compare(8, 9) < 0); + assertTrue(comparator.compare(4, 8) < 0); + assertTrue(comparator.compare(5, 9) < 0); + assertTrue(comparator.compare(2, 9) > 0); + } + } + + @Test + public void testCompareUInt8() { + try (UInt8Vector vec = new UInt8Vector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + vec.setNull(0); + vec.set(1, -2); + vec.set(2, -1); + vec.set(3, 0); + vec.set(4, 1); + vec.set(5, 2); + vec.set(6, -2); + vec.setNull(7); + vec.set(8, Long.MAX_VALUE); + vec.set(9, Long.MIN_VALUE); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + + assertTrue(comparator.compare(0, 1) < 0); + assertTrue(comparator.compare(1, 2) < 0); + assertTrue(comparator.compare(1, 3) > 0); + assertTrue(comparator.compare(2, 5) > 0); + assertTrue(comparator.compare(4, 5) < 0); + assertTrue(comparator.compare(1, 6) == 0); + assertTrue(comparator.compare(0, 7) == 0); + assertTrue(comparator.compare(8, 9) < 0); + assertTrue(comparator.compare(4, 8) < 0); + assertTrue(comparator.compare(5, 9) < 0); + assertTrue(comparator.compare(2, 9) > 0); + } + } + + @Test + public void testCompareLong() { + try (BigIntVector vec = new BigIntVector("", allocator)) { + vec.allocateNew(8); + ValueVectorDataPopulator.setVector( + vec, -1L, 0L, 1L, null, 1L, 5L, Long.MIN_VALUE + 1L, Long.MAX_VALUE); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + + assertTrue(comparator.compare(0, 1) < 0); + assertTrue(comparator.compare(0, 2) < 0); + assertTrue(comparator.compare(2, 1) > 0); + + // test equality + assertTrue(comparator.compare(5, 5) == 0); + assertTrue(comparator.compare(2, 4) == 0); + + // null first + assertTrue(comparator.compare(3, 4) < 0); + assertTrue(comparator.compare(5, 3) > 0); + + // potential overflow + assertTrue(comparator.compare(6, 7) < 0); + assertTrue(comparator.compare(7, 6) > 0); + assertTrue(comparator.compare(7, 7) == 0); + } + } + + @Test + public void testCompareInt() { + try (IntVector vec = new IntVector("", allocator)) { + vec.allocateNew(8); + ValueVectorDataPopulator.setVector( + vec, -1, 0, 1, null, 1, 5, Integer.MIN_VALUE + 1, Integer.MAX_VALUE); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + + assertTrue(comparator.compare(0, 1) < 0); + assertTrue(comparator.compare(0, 2) < 0); + assertTrue(comparator.compare(2, 1) > 0); + + // test equality + assertTrue(comparator.compare(5, 5) == 0); + assertTrue(comparator.compare(2, 4) == 0); + + // null first + assertTrue(comparator.compare(3, 4) < 0); + assertTrue(comparator.compare(5, 3) > 0); + + // potential overflow + assertTrue(comparator.compare(6, 7) < 0); + assertTrue(comparator.compare(7, 6) > 0); + assertTrue(comparator.compare(7, 7) == 0); + } + } + + @Test + public void testCompareShort() { + try (SmallIntVector vec = new SmallIntVector("", allocator)) { + vec.allocateNew(8); + ValueVectorDataPopulator.setVector( + vec, (short) -1, (short) 0, (short) 1, null, (short) 1, (short) 5, + (short) (Short.MIN_VALUE + 1), Short.MAX_VALUE); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + + assertTrue(comparator.compare(0, 1) < 0); + assertTrue(comparator.compare(0, 2) < 0); + assertTrue(comparator.compare(2, 1) > 0); + + // test equality + assertTrue(comparator.compare(5, 5) == 0); + assertTrue(comparator.compare(2, 4) == 0); + + // null first + assertTrue(comparator.compare(3, 4) < 0); + assertTrue(comparator.compare(5, 3) > 0); + + // potential overflow + assertTrue(comparator.compare(6, 7) < 0); + assertTrue(comparator.compare(7, 6) > 0); + assertTrue(comparator.compare(7, 7) == 0); + } + } + + @Test + public void testCompareByte() { + try (TinyIntVector vec = new TinyIntVector("", allocator)) { + vec.allocateNew(8); + ValueVectorDataPopulator.setVector( + vec, (byte) -1, (byte) 0, (byte) 1, null, (byte) 1, (byte) 5, + (byte) (Byte.MIN_VALUE + 1), Byte.MAX_VALUE); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + + assertTrue(comparator.compare(0, 1) < 0); + assertTrue(comparator.compare(0, 2) < 0); + assertTrue(comparator.compare(2, 1) > 0); + + // test equality + assertTrue(comparator.compare(5, 5) == 0); + assertTrue(comparator.compare(2, 4) == 0); + + // null first + assertTrue(comparator.compare(3, 4) < 0); + assertTrue(comparator.compare(5, 3) > 0); + + // potential overflow + assertTrue(comparator.compare(6, 7) < 0); + assertTrue(comparator.compare(7, 6) > 0); + assertTrue(comparator.compare(7, 7) == 0); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthInPlaceVectorSorter.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthInPlaceVectorSorter.java new file mode 100644 index 000000000..91ef52017 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthInPlaceVectorSorter.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +import java.util.stream.IntStream; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.testing.ValueVectorDataPopulator; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link FixedWidthInPlaceVectorSorter}. + */ +public class TestFixedWidthInPlaceVectorSorter { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testSortInt() { + try (IntVector vec = new IntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10); + vec.set(1, 8); + vec.setNull(2); + vec.set(3, 10); + vec.set(4, 12); + vec.set(5, 17); + vec.setNull(6); + vec.set(7, 23); + vec.set(8, 35); + vec.set(9, 2); + + // sort the vector + FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + sorter.sortInPlace(vec, comparator); + + // verify results + Assert.assertEquals(10, vec.getValueCount()); + + assertTrue(vec.isNull(0)); + assertTrue(vec.isNull(1)); + Assert.assertEquals(2, vec.get(2)); + Assert.assertEquals(8, vec.get(3)); + Assert.assertEquals(10, vec.get(4)); + Assert.assertEquals(10, vec.get(5)); + Assert.assertEquals(12, vec.get(6)); + Assert.assertEquals(17, vec.get(7)); + Assert.assertEquals(23, vec.get(8)); + Assert.assertEquals(35, vec.get(9)); + } + } + + /** + * Tests the worst case for quick sort. + * It may cause stack overflow if the algorithm is implemented as a recursive algorithm. + */ + @Test + public void testSortLargeIncreasingInt() { + final int vectorLength = 20000; + try (IntVector vec = new IntVector("", allocator)) { + vec.allocateNew(vectorLength); + vec.setValueCount(vectorLength); + + // fill data to sort + for (int i = 0; i < vectorLength; i++) { + vec.set(i, i); + } + + // sort the vector + FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + sorter.sortInPlace(vec, comparator); + + // verify results + Assert.assertEquals(vectorLength, vec.getValueCount()); + + for (int i = 0; i < vectorLength; i++) { + Assert.assertEquals(i, vec.get(i)); + } + } + } + + @Test + public void testChoosePivot() { + final int vectorLength = 100; + try (IntVector vec = new IntVector("", allocator)) { + vec.allocateNew(vectorLength); + + // the vector is sorted, so the pivot should be in the middle + for (int i = 0; i < vectorLength; i++) { + vec.set(i, i * 100); + } + vec.setValueCount(vectorLength); + + FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + try (IntVector pivotBuffer = (IntVector) vec.getField().createVector(allocator)) { + // setup internal data structures + pivotBuffer.allocateNew(1); + sorter.pivotBuffer = pivotBuffer; + sorter.comparator = comparator; + sorter.vec = vec; + comparator.attachVectors(vec, pivotBuffer); + + int low = 5; + int high = 6; + int pivotValue = vec.get(low); + assertTrue(high - low + 1 < FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD); + + // the range is small enough, so the pivot is simply selected as the low value + sorter.choosePivot(low, high); + assertEquals(pivotValue, vec.get(low)); + + low = 30; + high = 80; + pivotValue = vec.get((low + high) / 2); + assertTrue(high - low + 1 >= FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD); + + // the range is large enough, so the median is selected as the pivot + sorter.choosePivot(low, high); + assertEquals(pivotValue, vec.get(low)); + } + } + } + + /** + * Evaluates choosing pivot for all possible permutations of 3 numbers. + */ + @Test + public void testChoosePivotAllPermutes() { + try (IntVector vec = new IntVector("", allocator)) { + vec.allocateNew(3); + + FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + try (IntVector pivotBuffer = (IntVector) vec.getField().createVector(allocator)) { + // setup internal data structures + pivotBuffer.allocateNew(1); + sorter.pivotBuffer = pivotBuffer; + sorter.comparator = comparator; + sorter.vec = vec; + comparator.attachVectors(vec, pivotBuffer); + + int low = 0; + int high = 2; + + ValueVectorDataPopulator.setVector(vec, 11, 22, 33); + sorter.choosePivot(low, high); + assertEquals(22, vec.get(0)); + + ValueVectorDataPopulator.setVector(vec, 11, 33, 22); + sorter.choosePivot(low, high); + assertEquals(22, vec.get(0)); + + ValueVectorDataPopulator.setVector(vec, 22, 11, 33); + sorter.choosePivot(low, high); + assertEquals(22, vec.get(0)); + + ValueVectorDataPopulator.setVector(vec, 22, 33, 11); + sorter.choosePivot(low, high); + assertEquals(22, vec.get(0)); + + ValueVectorDataPopulator.setVector(vec, 33, 11, 22); + sorter.choosePivot(low, high); + assertEquals(22, vec.get(0)); + + ValueVectorDataPopulator.setVector(vec, 33, 22, 11); + sorter.choosePivot(low, high); + assertEquals(22, vec.get(0)); + } + } + } + + @Test + public void testSortInt2() { + try (IntVector vector = new IntVector("vector", allocator)) { + ValueVectorDataPopulator.setVector(vector, + 0, 1, 2, 3, 4, 5, 30, 31, 32, 33, + 34, 35, 60, 61, 62, 63, 64, 65, 6, 7, + 8, 9, 10, 11, 36, 37, 38, 39, 40, 41, + 66, 67, 68, 69, 70, 71); + + FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vector); + + sorter.sortInPlace(vector, comparator); + + int[] actual = new int[vector.getValueCount()]; + IntStream.range(0, vector.getValueCount()).forEach( + i -> actual[i] = vector.get(i)); + + assertArrayEquals( + new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71}, actual); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java new file mode 100644 index 000000000..e3701f1d8 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +import java.util.stream.IntStream; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.testing.ValueVectorDataPopulator; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link FixedWidthOutOfPlaceVectorSorter}. + */ +public class TestFixedWidthOutOfPlaceVectorSorter { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testSortByte() { + try (TinyIntVector vec = new TinyIntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10); + vec.set(1, 8); + vec.setNull(2); + vec.set(3, 10); + vec.set(4, 12); + vec.set(5, 17); + vec.setNull(6); + vec.set(7, 23); + vec.set(8, 35); + vec.set(9, 2); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + TinyIntVector sortedVec = + (TinyIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals((byte) 2, sortedVec.get(2)); + Assert.assertEquals((byte) 8, sortedVec.get(3)); + Assert.assertEquals((byte) 10, sortedVec.get(4)); + Assert.assertEquals((byte) 10, sortedVec.get(5)); + Assert.assertEquals((byte) 12, sortedVec.get(6)); + Assert.assertEquals((byte) 17, sortedVec.get(7)); + Assert.assertEquals((byte) 23, sortedVec.get(8)); + Assert.assertEquals((byte) 35, sortedVec.get(9)); + + sortedVec.close(); + } + } + + @Test + public void testSortShort() { + try (SmallIntVector vec = new SmallIntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10); + vec.set(1, 8); + vec.setNull(2); + vec.set(3, 10); + vec.set(4, 12); + vec.set(5, 17); + vec.setNull(6); + vec.set(7, 23); + vec.set(8, 35); + vec.set(9, 2); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + SmallIntVector sortedVec = + (SmallIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals((short) 2, sortedVec.get(2)); + Assert.assertEquals((short) 8, sortedVec.get(3)); + Assert.assertEquals((short) 10, sortedVec.get(4)); + Assert.assertEquals((short) 10, sortedVec.get(5)); + Assert.assertEquals((short) 12, sortedVec.get(6)); + Assert.assertEquals((short) 17, sortedVec.get(7)); + Assert.assertEquals((short) 23, sortedVec.get(8)); + Assert.assertEquals((short) 35, sortedVec.get(9)); + + sortedVec.close(); + } + } + + @Test + public void testSortInt() { + try (IntVector vec = new IntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10); + vec.set(1, 8); + vec.setNull(2); + vec.set(3, 10); + vec.set(4, 12); + vec.set(5, 17); + vec.setNull(6); + vec.set(7, 23); + vec.set(8, 35); + vec.set(9, 2); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + IntVector sortedVec = (IntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals(2, sortedVec.get(2)); + Assert.assertEquals(8, sortedVec.get(3)); + Assert.assertEquals(10, sortedVec.get(4)); + Assert.assertEquals(10, sortedVec.get(5)); + Assert.assertEquals(12, sortedVec.get(6)); + Assert.assertEquals(17, sortedVec.get(7)); + Assert.assertEquals(23, sortedVec.get(8)); + Assert.assertEquals(35, sortedVec.get(9)); + + sortedVec.close(); + } + } + + @Test + public void testSortLong() { + try (BigIntVector vec = new BigIntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10L); + vec.set(1, 8L); + vec.setNull(2); + vec.set(3, 10L); + vec.set(4, 12L); + vec.set(5, 17L); + vec.setNull(6); + vec.set(7, 23L); + vec.set(8, 1L << 35L); + vec.set(9, 2L); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + BigIntVector sortedVec = (BigIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals(2L, sortedVec.get(2)); + Assert.assertEquals(8L, sortedVec.get(3)); + Assert.assertEquals(10L, sortedVec.get(4)); + Assert.assertEquals(10L, sortedVec.get(5)); + Assert.assertEquals(12L, sortedVec.get(6)); + Assert.assertEquals(17L, sortedVec.get(7)); + Assert.assertEquals(23L, sortedVec.get(8)); + Assert.assertEquals(1L << 35L, sortedVec.get(9)); + + sortedVec.close(); + } + } + + @Test + public void testSortFloat() { + try (Float4Vector vec = new Float4Vector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10f); + vec.set(1, 8f); + vec.setNull(2); + vec.set(3, 10f); + vec.set(4, 12f); + vec.set(5, 17f); + vec.setNull(6); + vec.set(7, 23f); + vec.set(8, Float.NaN); + vec.set(9, 2f); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + Float4Vector sortedVec = (Float4Vector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals(2f, sortedVec.get(2), 0f); + Assert.assertEquals(8f, sortedVec.get(3), 0f); + Assert.assertEquals(10f, sortedVec.get(4), 0f); + Assert.assertEquals(10f, sortedVec.get(5), 0f); + Assert.assertEquals(12f, sortedVec.get(6), 0f); + Assert.assertEquals(17f, sortedVec.get(7), 0f); + Assert.assertEquals(23f, sortedVec.get(8), 0f); + Assert.assertEquals(Float.NaN, sortedVec.get(9), 0f); + + sortedVec.close(); + } + } + + @Test + public void testSortDouble() { + try (Float8Vector vec = new Float8Vector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10); + vec.set(1, 8); + vec.setNull(2); + vec.set(3, 10); + vec.set(4, 12); + vec.set(5, 17); + vec.setNull(6); + vec.set(7, Double.NaN); + vec.set(8, 35); + vec.set(9, 2); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + Float8Vector sortedVec = (Float8Vector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals(2, sortedVec.get(2), 0); + Assert.assertEquals(8, sortedVec.get(3), 0); + Assert.assertEquals(10, sortedVec.get(4), 0); + Assert.assertEquals(10, sortedVec.get(5), 0); + Assert.assertEquals(12, sortedVec.get(6), 0); + Assert.assertEquals(17, sortedVec.get(7), 0); + Assert.assertEquals(35, sortedVec.get(8), 0); + Assert.assertEquals(Double.NaN, sortedVec.get(9), 0); + + sortedVec.close(); + } + } + + @Test + public void testSortInt2() { + try (IntVector vec = new IntVector("", allocator)) { + ValueVectorDataPopulator.setVector(vec, + 0, 1, 2, 3, 4, 5, 30, 31, 32, 33, + 34, 35, 60, 61, 62, 63, 64, 65, 6, 7, + 8, 9, 10, 11, 36, 37, 38, 39, 40, 41, + 66, 67, 68, 69, 70, 71); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + try (IntVector sortedVec = (IntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null)) { + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + int[] actual = new int[sortedVec.getValueCount()]; + IntStream.range(0, sortedVec.getValueCount()).forEach( + i -> actual[i] = sortedVec.get(i)); + + assertArrayEquals( + new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71}, actual); + } + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthSorting.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthSorting.java new file mode 100644 index 000000000..ba2a341bf --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthSorting.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.function.Function; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * Test sorting fixed width vectors with random data. + */ +@RunWith(Parameterized.class) +public class TestFixedWidthSorting> { + + static final int[] VECTOR_LENGTHS = new int[] {2, 5, 10, 50, 100, 1000, 3000}; + + static final double[] NULL_FRACTIONS = {0, 0.1, 0.3, 0.5, 0.7, 0.9, 1}; + + private final int length; + + private final double nullFraction; + + private final boolean inPlace; + + private final Function vectorGenerator; + + private final TestSortingUtil.DataGenerator dataGenerator; + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + + @After + public void shutdown() { + allocator.close(); + } + + public TestFixedWidthSorting( + int length, double nullFraction, boolean inPlace, String desc, + Function vectorGenerator, TestSortingUtil.DataGenerator dataGenerator) { + this.length = length; + this.nullFraction = nullFraction; + this.inPlace = inPlace; + this.vectorGenerator = vectorGenerator; + this.dataGenerator = dataGenerator; + } + + @Test + public void testSort() { + if (inPlace) { + sortInPlace(); + } else { + sortOutOfPlace(); + } + } + + void sortInPlace() { + try (V vector = vectorGenerator.apply(allocator)) { + U[] array = dataGenerator.populate(vector, length, nullFraction); + TestSortingUtil.sortArray(array); + + FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vector); + + sorter.sortInPlace(vector, comparator); + + TestSortingUtil.verifyResults(vector, array); + } + } + + void sortOutOfPlace() { + try (V vector = vectorGenerator.apply(allocator)) { + U[] array = dataGenerator.populate(vector, length, nullFraction); + TestSortingUtil.sortArray(array); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vector); + + try (V sortedVec = (V) vector.getField().getFieldType().createNewSingleVector("", allocator, null)) { + sortedVec.allocateNew(vector.getValueCount()); + sortedVec.setValueCount(vector.getValueCount()); + + sorter.sortOutOfPlace(vector, sortedVec, comparator); + + // verify results + TestSortingUtil.verifyResults(sortedVec, array); + } + } + } + + @Parameterized.Parameters(name = "length = {0}, null fraction = {1}, in place = {2}, vector = {3}") + public static Collection getParameters() { + List params = new ArrayList<>(); + for (int length : VECTOR_LENGTHS) { + for (double nullFrac : NULL_FRACTIONS) { + for (boolean inPlace : new boolean[] {true, false}) { + params.add(new Object[] { + length, nullFrac, inPlace, "TinyIntVector", + (Function) (allocator -> new TinyIntVector("vector", allocator)), + TestSortingUtil.TINY_INT_GENERATOR + }); + + params.add(new Object[] { + length, nullFrac, inPlace, "SmallIntVector", + (Function) (allocator -> new SmallIntVector("vector", allocator)), + TestSortingUtil.SMALL_INT_GENERATOR + }); + + params.add(new Object[] { + length, nullFrac, inPlace, "IntVector", + (Function) (allocator -> new IntVector("vector", allocator)), + TestSortingUtil.INT_GENERATOR + }); + + params.add(new Object[] { + length, nullFrac, inPlace, "BigIntVector", + (Function) (allocator -> new BigIntVector("vector", allocator)), + TestSortingUtil.LONG_GENERATOR + }); + + params.add(new Object[] { + length, nullFrac, inPlace, "Float4Vector", + (Function) (allocator -> new Float4Vector("vector", allocator)), + TestSortingUtil.FLOAT_GENERATOR + }); + + params.add(new Object[] { + length, nullFrac, inPlace, "Float8Vector", + (Function) (allocator -> new Float8Vector("vector", allocator)), + TestSortingUtil.DOUBLE_GENERATOR + }); + } + } + } + return params; + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestIndexSorter.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestIndexSorter.java new file mode 100644 index 000000000..99e22f8bd --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestIndexSorter.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.testing.ValueVectorDataPopulator; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link IndexSorter}. + */ +public class TestIndexSorter { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testIndexSort() { + try (IntVector vec = new IntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + ValueVectorDataPopulator.setVector(vec, 11, 8, 33, 10, 12, 17, null, 23, 35, 2); + + // sort the index + IndexSorter indexSorter = new IndexSorter<>(); + DefaultVectorComparators.IntComparator intComparator = new DefaultVectorComparators.IntComparator(); + intComparator.attachVector(vec); + + IntVector indices = new IntVector("", allocator); + indices.setValueCount(10); + indexSorter.sort(vec, indices, intComparator); + + int[] expected = new int[]{6, 9, 1, 3, 0, 4, 5, 7, 2, 8}; + + for (int i = 0; i < expected.length; i++) { + assertTrue(!indices.isNull(i)); + assertEquals(expected[i], indices.get(i)); + } + indices.close(); + } + } + + /** + * Tests the worst case for quick sort. + * It may cause stack overflow if the algorithm is implemented as a recursive algorithm. + */ + @Test + public void testSortLargeIncreasingInt() { + final int vectorLength = 20000; + try (IntVector vec = new IntVector("", allocator)) { + vec.allocateNew(vectorLength); + vec.setValueCount(vectorLength); + + // fill data to sort + for (int i = 0; i < vectorLength; i++) { + vec.set(i, i); + } + + // sort the vector + IndexSorter indexSorter = new IndexSorter<>(); + DefaultVectorComparators.IntComparator intComparator = new DefaultVectorComparators.IntComparator(); + intComparator.attachVector(vec); + + try (IntVector indices = new IntVector("", allocator)) { + indices.setValueCount(vectorLength); + indexSorter.sort(vec, indices, intComparator); + + for (int i = 0; i < vectorLength; i++) { + assertTrue(!indices.isNull(i)); + assertEquals(i, indices.get(i)); + } + } + } + } + + @Test + public void testChoosePivot() { + final int vectorLength = 100; + try (IntVector vec = new IntVector("vector", allocator); + IntVector indices = new IntVector("indices", allocator)) { + vec.allocateNew(vectorLength); + indices.allocateNew(vectorLength); + + // the vector is sorted, so the pivot should be in the middle + for (int i = 0; i < vectorLength; i++) { + vec.set(i, i * 100); + indices.set(i, i); + } + vec.setValueCount(vectorLength); + indices.setValueCount(vectorLength); + + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + // setup internal data structures + comparator.attachVector(vec); + + int low = 5; + int high = 6; + assertTrue(high - low + 1 < FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD); + + // the range is small enough, so the pivot is simply selected as the low value + int pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator); + assertEquals(pivotIndex, low); + assertEquals(pivotIndex, indices.get(low)); + + low = 30; + high = 80; + assertTrue(high - low + 1 >= FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD); + + // the range is large enough, so the median is selected as the pivot + pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator); + assertEquals(pivotIndex, (low + high) / 2); + assertEquals(pivotIndex, indices.get(low)); + } + } + + /** + * Evaluates choosing pivot for all possible permutations of 3 numbers. + */ + @Test + public void testChoosePivotAllPermutes() { + try (IntVector vec = new IntVector("vector", allocator); + IntVector indices = new IntVector("indices", allocator)) { + vec.allocateNew(); + indices.allocateNew(); + + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + + // setup internal data structures + comparator.attachVector(vec); + int low = 0; + int high = 2; + + // test all the 6 permutations of 3 numbers + ValueVectorDataPopulator.setVector(indices, 0, 1, 2); + ValueVectorDataPopulator.setVector(vec, 11, 22, 33); + int pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator); + assertEquals(1, pivotIndex); + assertEquals(1, indices.get(low)); + + ValueVectorDataPopulator.setVector(indices, 0, 1, 2); + ValueVectorDataPopulator.setVector(vec, 11, 33, 22); + pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator); + assertEquals(2, pivotIndex); + assertEquals(2, indices.get(low)); + + ValueVectorDataPopulator.setVector(indices, 0, 1, 2); + ValueVectorDataPopulator.setVector(vec, 22, 11, 33); + pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator); + assertEquals(0, pivotIndex); + assertEquals(0, indices.get(low)); + + ValueVectorDataPopulator.setVector(indices, 0, 1, 2); + ValueVectorDataPopulator.setVector(vec, 22, 33, 11); + pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator); + assertEquals(0, pivotIndex); + assertEquals(0, indices.get(low)); + + ValueVectorDataPopulator.setVector(indices, 0, 1, 2); + ValueVectorDataPopulator.setVector(vec, 33, 11, 22); + pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator); + assertEquals(2, pivotIndex); + assertEquals(2, indices.get(low)); + + ValueVectorDataPopulator.setVector(indices, 0, 1, 2); + ValueVectorDataPopulator.setVector(vec, 33, 22, 11); + pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator); + assertEquals(1, pivotIndex); + assertEquals(1, indices.get(low)); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestInsertionSorter.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestInsertionSorter.java new file mode 100644 index 000000000..ba9c42913 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestInsertionSorter.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.testing.ValueVectorDataPopulator; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link InsertionSorter}. + */ +public class TestInsertionSorter { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + private static final int VECTOR_LENGTH = 10; + + private void testSortIntVectorRange(int start, int end, int[] expected) { + try (IntVector vector = new IntVector("vector", allocator); + IntVector buffer = new IntVector("buffer", allocator)) { + + buffer.allocateNew(1); + + ValueVectorDataPopulator.setVector(vector, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + assertEquals(VECTOR_LENGTH, vector.getValueCount()); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vector); + InsertionSorter.insertionSort(vector, start, end, comparator, buffer); + + assertEquals(VECTOR_LENGTH, expected.length); + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertFalse(vector.isNull(i)); + assertEquals(expected[i], vector.get(i)); + } + } + } + + @Test + public void testSortIntVector() { + testSortIntVectorRange(2, 5, new int[] {9, 8, 4, 5, 6, 7, 3, 2, 1, 0}); + testSortIntVectorRange(3, 7, new int[] {9, 8, 7, 2, 3, 4, 5, 6, 1, 0}); + testSortIntVectorRange(3, 4, new int[] {9, 8, 7, 5, 6, 4, 3, 2, 1, 0}); + testSortIntVectorRange(7, 7, new int[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + testSortIntVectorRange(0, 5, new int[] {4, 5, 6, 7, 8, 9, 3, 2, 1, 0}); + testSortIntVectorRange(8, 9, new int[] {9, 8, 7, 6, 5, 4, 3, 2, 0, 1}); + testSortIntVectorRange(0, 9, new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + } + + private void testSortIndicesRange(int start, int end, int[] expectedIndices) { + try (IntVector vector = new IntVector("vector", allocator); + IntVector indices = new IntVector("indices", allocator)) { + + ValueVectorDataPopulator.setVector(vector, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + ValueVectorDataPopulator.setVector(indices, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + + assertEquals(VECTOR_LENGTH, vector.getValueCount()); + assertEquals(VECTOR_LENGTH, indices.getValueCount()); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vector); + comparator.attachVector(vector); + + InsertionSorter.insertionSort(indices, start, end, comparator); + + // verify results + assertEquals(VECTOR_LENGTH, expectedIndices.length); + for (int i = 0; i < VECTOR_LENGTH; i++) { + assertFalse(indices.isNull(i)); + assertEquals(expectedIndices[i], indices.get(i)); + } + } + } + + @Test + public void testSortIndices() { + testSortIndicesRange(2, 5, new int[] {0, 1, 5, 4, 3, 2, 6, 7, 8, 9}); + testSortIndicesRange(3, 7, new int[] {0, 1, 2, 7, 6, 5, 4, 3, 8, 9}); + testSortIndicesRange(3, 4, new int[] {0, 1, 2, 4, 3, 5, 6, 7, 8, 9}); + testSortIndicesRange(7, 7, new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + testSortIndicesRange(0, 5, new int[] {5, 4, 3, 2, 1, 0, 6, 7, 8, 9}); + testSortIndicesRange(8, 9, new int[] {0, 1, 2, 3, 4, 5, 6, 7, 9, 8}); + testSortIndicesRange(0, 9, new int[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestOffHeapIntStack.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestOffHeapIntStack.java new file mode 100644 index 000000000..321ca226d --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestOffHeapIntStack.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static junit.framework.TestCase.assertEquals; +import static junit.framework.TestCase.assertTrue; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link OffHeapIntStack}. + */ +public class TestOffHeapIntStack { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testPushPop() { + try (OffHeapIntStack stack = new OffHeapIntStack(allocator)) { + assertTrue(stack.isEmpty()); + + final int elemCount = 500; + for (int i = 0; i < elemCount; i++) { + stack.push(i); + assertEquals(i, stack.getTop()); + } + + assertEquals(elemCount, stack.getCount()); + + for (int i = 0; i < elemCount; i++) { + assertEquals(elemCount - i - 1, stack.getTop()); + assertEquals(elemCount - i - 1, stack.pop()); + } + + assertTrue(stack.isEmpty()); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestSortingUtil.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestSortingUtil.java new file mode 100644 index 000000000..ea8655106 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestSortingUtil.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.lang.reflect.Array; +import java.util.Arrays; +import java.util.Random; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.testing.RandomDataGenerator; +import org.apache.arrow.vector.testing.ValueVectorDataPopulator; + +/** + * Utilities for sorting related utilities. + */ +public class TestSortingUtil { + + static final Random random = new Random(0); + + static final DataGenerator TINY_INT_GENERATOR = new DataGenerator<>( + RandomDataGenerator.TINY_INT_GENERATOR, + (vector, array) -> ValueVectorDataPopulator.setVector(vector, array), Byte.class); + + static final DataGenerator SMALL_INT_GENERATOR = new DataGenerator<>( + RandomDataGenerator.SMALL_INT_GENERATOR, + (vector, array) -> ValueVectorDataPopulator.setVector(vector, array), Short.class); + + static final DataGenerator INT_GENERATOR = new DataGenerator<>( + RandomDataGenerator.INT_GENERATOR, + (vector, array) -> ValueVectorDataPopulator.setVector(vector, array), Integer.class); + + static final DataGenerator LONG_GENERATOR = new DataGenerator<>( + RandomDataGenerator.LONG_GENERATOR, + (vector, array) -> ValueVectorDataPopulator.setVector(vector, array), Long.class); + + static final DataGenerator FLOAT_GENERATOR = new DataGenerator<>( + RandomDataGenerator.FLOAT_GENERATOR, + (vector, array) -> ValueVectorDataPopulator.setVector(vector, array), Float.class); + + static final DataGenerator DOUBLE_GENERATOR = new DataGenerator<>( + RandomDataGenerator.DOUBLE_GENERATOR, + (vector, array) -> ValueVectorDataPopulator.setVector(vector, array), Double.class); + + static final DataGenerator STRING_GENERATOR = new DataGenerator<>( + () -> { + int strLength = random.nextInt(20) + 1; + return generateRandomString(strLength); + }, + (vector, array) -> ValueVectorDataPopulator.setVector(vector, array), String.class); + + private TestSortingUtil() { + } + + /** + * Verify that a vector is equal to an array. + */ + public static void verifyResults(V vector, U[] expected) { + assertEquals(vector.getValueCount(), expected.length); + for (int i = 0; i < expected.length; i++) { + assertEquals(vector.getObject(i), expected[i]); + } + } + + /** + * Sort an array with null values come first. + */ + public static > void sortArray(U[] array) { + Arrays.sort(array, (a, b) -> { + if (a == null || b == null) { + if (a == null && b == null) { + return 0; + } + + // exactly one is null + if (a == null) { + return -1; + } else { + return 1; + } + } + return a.compareTo(b); + }); + } + + /** + * Generate a string with alphabetic characters only. + */ + static String generateRandomString(int length) { + byte[] str = new byte[length]; + final int lower = 'a'; + final int upper = 'z'; + + for (int i = 0; i < length; i++) { + // make r non-negative + int r = random.nextInt() & Integer.MAX_VALUE; + str[i] = (byte) (r % (upper - lower + 1) + lower); + } + + return new String(str); + } + + /** + * Utility to generate data for testing. + * @param vector type. + * @param data element type. + */ + static class DataGenerator> { + + final Supplier dataGenerator; + + final BiConsumer vectorPopulator; + + final Class clazz; + + DataGenerator( + Supplier dataGenerator, BiConsumer vectorPopulator, Class clazz) { + this.dataGenerator = dataGenerator; + this.vectorPopulator = vectorPopulator; + this.clazz = clazz; + } + + /** + * Populate the vector according to the specified parameters. + * @param vector the vector to populate. + * @param length vector length. + * @param nullFraction the fraction of null values. + * @return An array with the same data as the vector. + */ + U[] populate(V vector, int length, double nullFraction) { + U[] array = (U[]) Array.newInstance(clazz, length); + for (int i = 0; i < length; i++) { + double r = Math.random(); + U value = r < nullFraction ? null : dataGenerator.get(); + array[i] = value; + } + vectorPopulator.accept(vector, array); + return array; + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestStableVectorComparator.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestStableVectorComparator.java new file mode 100644 index 000000000..074193594 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestStableVectorComparator.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VarCharVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link StableVectorComparator}. + */ +public class TestStableVectorComparator { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testCompare() { + try (VarCharVector vec = new VarCharVector("", allocator)) { + vec.allocateNew(100, 5); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, "ba".getBytes()); + vec.set(1, "abc".getBytes()); + vec.set(2, "aa".getBytes()); + vec.set(3, "abc".getBytes()); + vec.set(4, "a".getBytes()); + + VectorValueComparator comparator = new TestVarCharSorter(); + VectorValueComparator stableComparator = new StableVectorComparator<>(comparator); + stableComparator.attachVector(vec); + + assertTrue(stableComparator.compare(0, 1) > 0); + assertTrue(stableComparator.compare(1, 2) < 0); + assertTrue(stableComparator.compare(2, 3) < 0); + assertTrue(stableComparator.compare(1, 3) < 0); + assertTrue(stableComparator.compare(3, 1) > 0); + assertTrue(stableComparator.compare(3, 3) == 0); + } + } + + @Test + public void testStableSortString() { + try (VarCharVector vec = new VarCharVector("", allocator)) { + vec.allocateNew(100, 10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, "a".getBytes()); + vec.set(1, "abc".getBytes()); + vec.set(2, "aa".getBytes()); + vec.set(3, "a1".getBytes()); + vec.set(4, "abcdefg".getBytes()); + vec.set(5, "accc".getBytes()); + vec.set(6, "afds".getBytes()); + vec.set(7, "0".getBytes()); + vec.set(8, "01".getBytes()); + vec.set(9, "0c".getBytes()); + + // sort the vector + VariableWidthOutOfPlaceVectorSorter sorter = new VariableWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = new TestVarCharSorter(); + VectorValueComparator stableComparator = new StableVectorComparator<>(comparator); + + try (VarCharVector sortedVec = + (VarCharVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null)) { + sortedVec.allocateNew(vec.getByteCapacity(), vec.getValueCount()); + sortedVec.setLastSet(vec.getValueCount() - 1); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, stableComparator); + + // verify results + // the results are stable + assertEquals("0", new String(sortedVec.get(0))); + assertEquals("01", new String(sortedVec.get(1))); + assertEquals("0c", new String(sortedVec.get(2))); + assertEquals("a", new String(sortedVec.get(3))); + assertEquals("abc", new String(sortedVec.get(4))); + assertEquals("aa", new String(sortedVec.get(5))); + assertEquals("a1", new String(sortedVec.get(6))); + assertEquals("abcdefg", new String(sortedVec.get(7))); + assertEquals("accc", new String(sortedVec.get(8))); + assertEquals("afds", new String(sortedVec.get(9))); + } + } + } + + /** + * Utility comparator that compares varchars by the first character. + */ + private static class TestVarCharSorter extends VectorValueComparator { + + @Override + public int compareNotNull(int index1, int index2) { + byte b1 = vector1.get(index1)[0]; + byte b2 = vector2.get(index2)[0]; + return b1 - b2; + } + + @Override + public VectorValueComparator createNew() { + return new TestVarCharSorter(); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthOutOfPlaceVectorSorter.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthOutOfPlaceVectorSorter.java new file mode 100644 index 000000000..46b306021 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthOutOfPlaceVectorSorter.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.VarCharVector; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link VariableWidthOutOfPlaceVectorSorter}. + */ +public class TestVariableWidthOutOfPlaceVectorSorter { + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testSortString() { + try (VarCharVector vec = new VarCharVector("", allocator)) { + vec.allocateNew(100, 10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, "hello".getBytes()); + vec.set(1, "abc".getBytes()); + vec.setNull(2); + vec.set(3, "world".getBytes()); + vec.set(4, "12".getBytes()); + vec.set(5, "dictionary".getBytes()); + vec.setNull(6); + vec.set(7, "hello".getBytes()); + vec.set(8, "good".getBytes()); + vec.set(9, "yes".getBytes()); + + // sort the vector + VariableWidthOutOfPlaceVectorSorter sorter = new VariableWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + + VarCharVector sortedVec = + (VarCharVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getByteCapacity(), vec.getValueCount()); + sortedVec.setLastSet(vec.getValueCount() - 1); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + Assert.assertEquals(vec.getByteCapacity(), sortedVec.getByteCapacity()); + Assert.assertEquals(vec.getLastSet(), sortedVec.getLastSet()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + assertEquals("12", new String(sortedVec.get(2))); + assertEquals("abc", new String(sortedVec.get(3))); + assertEquals("dictionary", new String(sortedVec.get(4))); + assertEquals("good", new String(sortedVec.get(5))); + assertEquals("hello", new String(sortedVec.get(6))); + assertEquals("hello", new String(sortedVec.get(7))); + assertEquals("world", new String(sortedVec.get(8))); + assertEquals("yes", new String(sortedVec.get(9))); + + sortedVec.close(); + } + } +} diff --git a/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthSorting.java b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthSorting.java new file mode 100644 index 000000000..068fe8b69 --- /dev/null +++ b/src/arrow/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthSorting.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.sort; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.function.Function; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.Text; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * Test sorting variable width vectors with random data. + */ +@RunWith(Parameterized.class) +public class TestVariableWidthSorting> { + + static final int[] VECTOR_LENGTHS = new int[] {2, 5, 10, 50, 100, 1000, 3000}; + + static final double[] NULL_FRACTIONS = {0, 0.1, 0.3, 0.5, 0.7, 0.9, 1}; + + private final int length; + + private final double nullFraction; + + private final Function vectorGenerator; + + private final TestSortingUtil.DataGenerator dataGenerator; + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + + @After + public void shutdown() { + allocator.close(); + } + + public TestVariableWidthSorting( + int length, double nullFraction, String desc, + Function vectorGenerator, TestSortingUtil.DataGenerator dataGenerator) { + this.length = length; + this.nullFraction = nullFraction; + this.vectorGenerator = vectorGenerator; + this.dataGenerator = dataGenerator; + } + + @Test + public void testSort() { + sortOutOfPlace(); + } + + void sortOutOfPlace() { + try (V vector = vectorGenerator.apply(allocator)) { + U[] array = dataGenerator.populate(vector, length, nullFraction); + Arrays.sort(array, (Comparator) new StringComparator()); + + // sort the vector + VariableWidthOutOfPlaceVectorSorter sorter = new VariableWidthOutOfPlaceVectorSorter(); + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vector); + + try (V sortedVec = (V) vector.getField().getFieldType().createNewSingleVector("", allocator, null)) { + int dataSize = vector.getOffsetBuffer().getInt(vector.getValueCount() * 4); + sortedVec.allocateNew(dataSize, vector.getValueCount()); + sortedVec.setValueCount(vector.getValueCount()); + + sorter.sortOutOfPlace(vector, sortedVec, comparator); + + // verify results + verifyResults(sortedVec, (String[]) array); + } + } + } + + @Parameterized.Parameters(name = "length = {0}, null fraction = {1}, vector = {2}") + public static Collection getParameters() { + List params = new ArrayList<>(); + for (int length : VECTOR_LENGTHS) { + for (double nullFrac : NULL_FRACTIONS) { + params.add(new Object[]{ + length, nullFrac, "VarCharVector", + (Function) (allocator -> new VarCharVector("vector", allocator)), + TestSortingUtil.STRING_GENERATOR + }); + } + } + return params; + } + + /** + * Verify results as byte arrays. + */ + public static void verifyResults(V vector, String[] expected) { + assertEquals(vector.getValueCount(), expected.length); + for (int i = 0; i < expected.length; i++) { + if (expected[i] == null) { + assertTrue(vector.isNull(i)); + } else { + assertArrayEquals(((Text) vector.getObject(i)).getBytes(), expected[i].getBytes()); + } + } + } + + /** + * String comparator with the same behavior as that of + * {@link DefaultVectorComparators.VariableWidthComparator}. + */ + static class StringComparator implements Comparator { + + @Override + public int compare(String str1, String str2) { + if (str1 == null || str2 == null) { + if (str1 == null && str2 == null) { + return 0; + } + + return str1 == null ? -1 : 1; + } + + byte[] bytes1 = str1.getBytes(); + byte[] bytes2 = str2.getBytes(); + + for (int i = 0; i < bytes1.length && i < bytes2.length; i++) { + if (bytes1[i] != bytes2[i]) { + return (bytes1[i] & 0xff) < (bytes2[i] & 0xff) ? -1 : 1; + } + } + return bytes1.length - bytes2.length; + } + } +} -- cgit v1.2.3