summaryrefslogtreecommitdiffstats
path: root/src/arrow/js/test/unit/dataframe-tests.ts
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/js/test/unit/dataframe-tests.ts')
-rw-r--r--src/arrow/js/test/unit/dataframe-tests.ts282
1 files changed, 282 insertions, 0 deletions
diff --git a/src/arrow/js/test/unit/dataframe-tests.ts b/src/arrow/js/test/unit/dataframe-tests.ts
new file mode 100644
index 000000000..9e87e372d
--- /dev/null
+++ b/src/arrow/js/test/unit/dataframe-tests.ts
@@ -0,0 +1,282 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import '../jest-extensions';
+import {
+ predicate, DataFrame, RecordBatch
+} from 'apache-arrow';
+import { test_data } from './table-tests';
+import { jest } from '@jest/globals';
+
+const { col, lit, custom, and, or, And, Or } = predicate;
+
+const F32 = 0, I32 = 1, DICT = 2;
+
+describe(`DataFrame`, () => {
+
+ for (let datum of test_data) {
+ describe(datum.name, () => {
+
+ describe(`scan()`, () => {
+ test(`yields all values`, () => {
+ const df = new DataFrame(datum.table());
+ let expected_idx = 0;
+ df.scan((idx, batch) => {
+ const columns = batch.schema.fields.map((_, i) => batch.getChildAt(i)!);
+ expect(columns.map((c) => c.get(idx))).toEqual(values[expected_idx++]);
+ });
+ });
+ test(`calls bind function with every batch`, () => {
+ const df = new DataFrame(datum.table());
+ let bind = jest.fn();
+ df.scan(() => { }, bind);
+ for (let batch of df.chunks) {
+ expect(bind).toHaveBeenCalledWith(batch);
+ }
+ });
+ });
+ describe(`scanReverse()`, () => {
+ test(`yields all values`, () => {
+ const df = new DataFrame(datum.table());
+ let expected_idx = values.length;
+ df.scanReverse((idx, batch) => {
+ const columns = batch.schema.fields.map((_, i) => batch.getChildAt(i)!);
+ expect(columns.map((c) => c.get(idx))).toEqual(values[--expected_idx]);
+ });
+ });
+ test(`calls bind function with every batch`, () => {
+ const df = new DataFrame(datum.table());
+ let bind = jest.fn();
+ df.scanReverse(() => { }, bind);
+ for (let batch of df.chunks) {
+ expect(bind).toHaveBeenCalledWith(batch);
+ }
+ });
+ });
+ test(`count() returns the correct length`, () => {
+ const df = new DataFrame(datum.table());
+ const values = datum.values();
+ expect(df.count()).toEqual(values.length);
+ });
+ test(`getColumnIndex`, () => {
+ const df = new DataFrame(datum.table());
+ expect(df.getColumnIndex('i32')).toEqual(I32);
+ expect(df.getColumnIndex('f32')).toEqual(F32);
+ expect(df.getColumnIndex('dictionary')).toEqual(DICT);
+ });
+ const df = new DataFrame(datum.table());
+ const values = datum.values();
+ let get_i32: (idx: number) => number, get_f32: (idx: number) => number;
+ const filter_tests = [
+ {
+ name: `filter on f32 >= 0`,
+ filtered: df.filter(col('f32').ge(0)),
+ expected: values.filter((row) => row[F32] >= 0)
+ }, {
+ name: `filter on 0 <= f32`,
+ filtered: df.filter(lit(0).le(col('f32'))),
+ expected: values.filter((row) => 0 <= row[F32])
+ }, {
+ name: `filter on i32 <= 0`,
+ filtered: df.filter(col('i32').le(0)),
+ expected: values.filter((row) => row[I32] <= 0)
+ }, {
+ name: `filter on 0 >= i32`,
+ filtered: df.filter(lit(0).ge(col('i32'))),
+ expected: values.filter((row) => 0 >= row[I32])
+ }, {
+ name: `filter on f32 < 0`,
+ filtered: df.filter(col('f32').lt(0)),
+ expected: values.filter((row) => row[F32] < 0)
+ }, {
+ name: `filter on i32 > 1 (empty)`,
+ filtered: df.filter(col('i32').gt(0)),
+ expected: values.filter((row) => row[I32] > 0)
+ }, {
+ name: `filter on f32 <= -.25 || f3 >= .25`,
+ filtered: df.filter(col('f32').le(-.25).or(col('f32').ge(.25))),
+ expected: values.filter((row) => row[F32] <= -.25 || row[F32] >= .25)
+ }, {
+ name: `filter on !(f32 <= -.25 || f3 >= .25) (not)`,
+ filtered: df.filter(col('f32').le(-.25).or(col('f32').ge(.25)).not()),
+ expected: values.filter((row) => !(row[F32] <= -.25 || row[F32] >= .25))
+ }, {
+ name: `filter method combines predicates (f32 >= 0 && i32 <= 0)`,
+ filtered: df.filter(col('i32').le(0)).filter(col('f32').ge(0)),
+ expected: values.filter((row) => row[I32] <= 0 && row[F32] >= 0)
+ }, {
+ name: `filter on dictionary == 'a'`,
+ filtered: df.filter(col('dictionary').eq('a')),
+ expected: values.filter((row) => row[DICT] === 'a')
+ }, {
+ name: `filter on 'a' == dictionary (commutativity)`,
+ filtered: df.filter(lit('a').eq(col('dictionary'))),
+ expected: values.filter((row) => row[DICT] === 'a')
+ }, {
+ name: `filter on dictionary != 'b'`,
+ filtered: df.filter(col('dictionary').ne('b')),
+ expected: values.filter((row) => row[DICT] !== 'b')
+ }, {
+ name: `filter on f32 >= i32`,
+ filtered: df.filter(col('f32').ge(col('i32'))),
+ expected: values.filter((row) => row[F32] >= row[I32])
+ }, {
+ name: `filter on f32 <= i32`,
+ filtered: df.filter(col('f32').le(col('i32'))),
+ expected: values.filter((row) => row[F32] <= row[I32])
+ }, {
+ name: `filter on f32*i32 > 0 (custom predicate)`,
+ filtered: df.filter(custom(
+ (idx: number) => (get_f32(idx) * get_i32(idx) > 0),
+ (batch: RecordBatch) => {
+ get_f32 = col('f32').bind(batch);
+ get_i32 = col('i32').bind(batch);
+ })),
+ expected: values.filter((row) => (row[F32] as number) * (row[I32] as number) > 0)
+ }, {
+ name: `filter out all records`,
+ filtered: df.filter(lit(1).eq(0)),
+ expected: []
+ }
+ ];
+ for (let this_test of filter_tests) {
+ const { name, filtered, expected } = this_test;
+ describe(name, () => {
+ test(`count() returns the correct length`, () => {
+ expect(filtered.count()).toEqual(expected.length);
+ });
+ describe(`scan()`, () => {
+ test(`iterates over expected values`, () => {
+ let expected_idx = 0;
+ filtered.scan((idx, batch) => {
+ const columns = batch.schema.fields.map((_, i) => batch.getChildAt(i)!);
+ expect(columns.map((c) => c.get(idx))).toEqual(expected[expected_idx++]);
+ });
+ });
+ test(`calls bind function lazily`, () => {
+ let bind = jest.fn();
+ filtered.scan(() => { }, bind);
+ if (expected.length) {
+ expect(bind).toHaveBeenCalled();
+ } else {
+ expect(bind).not.toHaveBeenCalled();
+ }
+ });
+ });
+ describe(`scanReverse()`, () => {
+ test(`iterates over expected values in reverse`, () => {
+ let expected_idx = expected.length;
+ filtered.scanReverse((idx, batch) => {
+ const columns = batch.schema.fields.map((_, i) => batch.getChildAt(i)!);
+ expect(columns.map((c) => c.get(idx))).toEqual(expected[--expected_idx]);
+ });
+ });
+ test(`calls bind function lazily`, () => {
+ let bind = jest.fn();
+ filtered.scanReverse(() => { }, bind);
+ if (expected.length) {
+ expect(bind).toHaveBeenCalled();
+ } else {
+ expect(bind).not.toHaveBeenCalled();
+ }
+ });
+ });
+ });
+ }
+ test(`countBy on dictionary returns the correct counts`, () => {
+ // Make sure countBy works both with and without the Col wrapper
+ // class
+ let expected: { [key: string]: number } = { 'a': 0, 'b': 0, 'c': 0 };
+ for (let row of values) {
+ expected[row[DICT]] += 1;
+ }
+
+ expect(df.countBy(col('dictionary')).toJSON()).toEqual(expected);
+ expect(df.countBy('dictionary').toJSON()).toEqual(expected);
+ });
+ test(`countBy on dictionary with filter returns the correct counts`, () => {
+ let expected: { [key: string]: number } = { 'a': 0, 'b': 0, 'c': 0 };
+ for (let row of values) {
+ if (row[I32] === 1) { expected[row[DICT]] += 1; }
+ }
+
+ expect(df.filter(col('i32').eq(1)).countBy('dictionary').toJSON()).toEqual(expected);
+ });
+ test(`countBy on non dictionary column throws error`, () => {
+ expect(() => { df.countBy('i32'); }).toThrow();
+ expect(() => { df.filter(col('dict').eq('a')).countBy('i32'); }).toThrow();
+ });
+ test(`countBy on non-existent column throws error`, () => {
+ expect(() => { df.countBy('FAKE' as any); }).toThrow();
+ });
+ test(`table.select() basic tests`, () => {
+ let selected = df.select('f32', 'dictionary');
+ expect(selected.schema.fields).toHaveLength(2);
+ expect(selected.schema.fields[0]).toEqual(df.schema.fields[0]);
+ expect(selected.schema.fields[1]).toEqual(df.schema.fields[2]);
+
+ expect(selected).toHaveLength(values.length);
+ let idx = 0, expected_row;
+ for (let row of selected) {
+ expected_row = values[idx++];
+ expect(row.f32).toEqual(expected_row[F32]);
+ expect(row.dictionary).toEqual(expected_row[DICT]);
+ }
+ });
+ test(`table.filter(..).count() on always false predicates returns 0`, () => {
+ expect(df.filter(col('i32').ge(100)).count()).toEqual(0);
+ expect(df.filter(col('dictionary').eq('z')).count()).toEqual(0);
+ });
+ describe(`lit-lit comparison`, () => {
+ test(`always-false count() returns 0`, () => {
+ expect(df.filter(lit('abc').eq('def')).count()).toEqual(0);
+ expect(df.filter(lit(0).ge(1)).count()).toEqual(0);
+ });
+ test(`always-true count() returns length`, () => {
+ expect(df.filter(lit('abc').eq('abc')).count()).toEqual(df.length);
+ expect(df.filter(lit(-100).le(0)).count()).toEqual(df.length);
+ });
+ });
+ describe(`col-col comparison`, () => {
+ test(`always-false count() returns 0`, () => {
+ expect(df.filter(col('dictionary').eq(col('i32'))).count()).toEqual(0);
+ });
+ test(`always-true count() returns length`, () => {
+ expect(df.filter(col('dictionary').eq(col('dictionary'))).count()).toEqual(df.length);
+ });
+ });
+ });
+ }
+});
+
+describe(`Predicate`, () => {
+ const p1 = col('a').gt(100);
+ const p2 = col('a').lt(1000);
+ const p3 = col('b').eq('foo');
+ const p4 = col('c').eq('bar');
+ const expected = [p1, p2, p3, p4];
+ test(`and flattens children`, () => {
+ expect(and(p1, p2, p3, p4).children).toEqual(expected);
+ expect(and(p1.and(p2), new And(p3, p4)).children).toEqual(expected);
+ expect(and(p1.and(p2, p3, p4)).children).toEqual(expected);
+ });
+ test(`or flattens children`, () => {
+ expect(or(p1, p2, p3, p4).children).toEqual(expected);
+ expect(or(p1.or(p2), new Or(p3, p4)).children).toEqual(expected);
+ expect(or(p1.or(p2, p3, p4)).children).toEqual(expected);
+ });
+});