diff options
Diffstat (limited to 'src/arrow/js/test/unit/dataframe-tests.ts')
-rw-r--r-- | src/arrow/js/test/unit/dataframe-tests.ts | 282 |
1 files changed, 282 insertions, 0 deletions
diff --git a/src/arrow/js/test/unit/dataframe-tests.ts b/src/arrow/js/test/unit/dataframe-tests.ts new file mode 100644 index 000000000..9e87e372d --- /dev/null +++ b/src/arrow/js/test/unit/dataframe-tests.ts @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import '../jest-extensions'; +import { + predicate, DataFrame, RecordBatch +} from 'apache-arrow'; +import { test_data } from './table-tests'; +import { jest } from '@jest/globals'; + +const { col, lit, custom, and, or, And, Or } = predicate; + +const F32 = 0, I32 = 1, DICT = 2; + +describe(`DataFrame`, () => { + + for (let datum of test_data) { + describe(datum.name, () => { + + describe(`scan()`, () => { + test(`yields all values`, () => { + const df = new DataFrame(datum.table()); + let expected_idx = 0; + df.scan((idx, batch) => { + const columns = batch.schema.fields.map((_, i) => batch.getChildAt(i)!); + expect(columns.map((c) => c.get(idx))).toEqual(values[expected_idx++]); + }); + }); + test(`calls bind function with every batch`, () => { + const df = new DataFrame(datum.table()); + let bind = jest.fn(); + df.scan(() => { }, bind); + for (let batch of df.chunks) { + expect(bind).toHaveBeenCalledWith(batch); + } + }); + }); + describe(`scanReverse()`, () => { + test(`yields all values`, () => { + const df = new DataFrame(datum.table()); + let expected_idx = values.length; + df.scanReverse((idx, batch) => { + const columns = batch.schema.fields.map((_, i) => batch.getChildAt(i)!); + expect(columns.map((c) => c.get(idx))).toEqual(values[--expected_idx]); + }); + }); + test(`calls bind function with every batch`, () => { + const df = new DataFrame(datum.table()); + let bind = jest.fn(); + df.scanReverse(() => { }, bind); + for (let batch of df.chunks) { + expect(bind).toHaveBeenCalledWith(batch); + } + }); + }); + test(`count() returns the correct length`, () => { + const df = new DataFrame(datum.table()); + const values = datum.values(); + expect(df.count()).toEqual(values.length); + }); + test(`getColumnIndex`, () => { + const df = new DataFrame(datum.table()); + expect(df.getColumnIndex('i32')).toEqual(I32); + expect(df.getColumnIndex('f32')).toEqual(F32); + expect(df.getColumnIndex('dictionary')).toEqual(DICT); + }); + const df = new DataFrame(datum.table()); + const values = datum.values(); + let get_i32: (idx: number) => number, get_f32: (idx: number) => number; + const filter_tests = [ + { + name: `filter on f32 >= 0`, + filtered: df.filter(col('f32').ge(0)), + expected: values.filter((row) => row[F32] >= 0) + }, { + name: `filter on 0 <= f32`, + filtered: df.filter(lit(0).le(col('f32'))), + expected: values.filter((row) => 0 <= row[F32]) + }, { + name: `filter on i32 <= 0`, + filtered: df.filter(col('i32').le(0)), + expected: values.filter((row) => row[I32] <= 0) + }, { + name: `filter on 0 >= i32`, + filtered: df.filter(lit(0).ge(col('i32'))), + expected: values.filter((row) => 0 >= row[I32]) + }, { + name: `filter on f32 < 0`, + filtered: df.filter(col('f32').lt(0)), + expected: values.filter((row) => row[F32] < 0) + }, { + name: `filter on i32 > 1 (empty)`, + filtered: df.filter(col('i32').gt(0)), + expected: values.filter((row) => row[I32] > 0) + }, { + name: `filter on f32 <= -.25 || f3 >= .25`, + filtered: df.filter(col('f32').le(-.25).or(col('f32').ge(.25))), + expected: values.filter((row) => row[F32] <= -.25 || row[F32] >= .25) + }, { + name: `filter on !(f32 <= -.25 || f3 >= .25) (not)`, + filtered: df.filter(col('f32').le(-.25).or(col('f32').ge(.25)).not()), + expected: values.filter((row) => !(row[F32] <= -.25 || row[F32] >= .25)) + }, { + name: `filter method combines predicates (f32 >= 0 && i32 <= 0)`, + filtered: df.filter(col('i32').le(0)).filter(col('f32').ge(0)), + expected: values.filter((row) => row[I32] <= 0 && row[F32] >= 0) + }, { + name: `filter on dictionary == 'a'`, + filtered: df.filter(col('dictionary').eq('a')), + expected: values.filter((row) => row[DICT] === 'a') + }, { + name: `filter on 'a' == dictionary (commutativity)`, + filtered: df.filter(lit('a').eq(col('dictionary'))), + expected: values.filter((row) => row[DICT] === 'a') + }, { + name: `filter on dictionary != 'b'`, + filtered: df.filter(col('dictionary').ne('b')), + expected: values.filter((row) => row[DICT] !== 'b') + }, { + name: `filter on f32 >= i32`, + filtered: df.filter(col('f32').ge(col('i32'))), + expected: values.filter((row) => row[F32] >= row[I32]) + }, { + name: `filter on f32 <= i32`, + filtered: df.filter(col('f32').le(col('i32'))), + expected: values.filter((row) => row[F32] <= row[I32]) + }, { + name: `filter on f32*i32 > 0 (custom predicate)`, + filtered: df.filter(custom( + (idx: number) => (get_f32(idx) * get_i32(idx) > 0), + (batch: RecordBatch) => { + get_f32 = col('f32').bind(batch); + get_i32 = col('i32').bind(batch); + })), + expected: values.filter((row) => (row[F32] as number) * (row[I32] as number) > 0) + }, { + name: `filter out all records`, + filtered: df.filter(lit(1).eq(0)), + expected: [] + } + ]; + for (let this_test of filter_tests) { + const { name, filtered, expected } = this_test; + describe(name, () => { + test(`count() returns the correct length`, () => { + expect(filtered.count()).toEqual(expected.length); + }); + describe(`scan()`, () => { + test(`iterates over expected values`, () => { + let expected_idx = 0; + filtered.scan((idx, batch) => { + const columns = batch.schema.fields.map((_, i) => batch.getChildAt(i)!); + expect(columns.map((c) => c.get(idx))).toEqual(expected[expected_idx++]); + }); + }); + test(`calls bind function lazily`, () => { + let bind = jest.fn(); + filtered.scan(() => { }, bind); + if (expected.length) { + expect(bind).toHaveBeenCalled(); + } else { + expect(bind).not.toHaveBeenCalled(); + } + }); + }); + describe(`scanReverse()`, () => { + test(`iterates over expected values in reverse`, () => { + let expected_idx = expected.length; + filtered.scanReverse((idx, batch) => { + const columns = batch.schema.fields.map((_, i) => batch.getChildAt(i)!); + expect(columns.map((c) => c.get(idx))).toEqual(expected[--expected_idx]); + }); + }); + test(`calls bind function lazily`, () => { + let bind = jest.fn(); + filtered.scanReverse(() => { }, bind); + if (expected.length) { + expect(bind).toHaveBeenCalled(); + } else { + expect(bind).not.toHaveBeenCalled(); + } + }); + }); + }); + } + test(`countBy on dictionary returns the correct counts`, () => { + // Make sure countBy works both with and without the Col wrapper + // class + let expected: { [key: string]: number } = { 'a': 0, 'b': 0, 'c': 0 }; + for (let row of values) { + expected[row[DICT]] += 1; + } + + expect(df.countBy(col('dictionary')).toJSON()).toEqual(expected); + expect(df.countBy('dictionary').toJSON()).toEqual(expected); + }); + test(`countBy on dictionary with filter returns the correct counts`, () => { + let expected: { [key: string]: number } = { 'a': 0, 'b': 0, 'c': 0 }; + for (let row of values) { + if (row[I32] === 1) { expected[row[DICT]] += 1; } + } + + expect(df.filter(col('i32').eq(1)).countBy('dictionary').toJSON()).toEqual(expected); + }); + test(`countBy on non dictionary column throws error`, () => { + expect(() => { df.countBy('i32'); }).toThrow(); + expect(() => { df.filter(col('dict').eq('a')).countBy('i32'); }).toThrow(); + }); + test(`countBy on non-existent column throws error`, () => { + expect(() => { df.countBy('FAKE' as any); }).toThrow(); + }); + test(`table.select() basic tests`, () => { + let selected = df.select('f32', 'dictionary'); + expect(selected.schema.fields).toHaveLength(2); + expect(selected.schema.fields[0]).toEqual(df.schema.fields[0]); + expect(selected.schema.fields[1]).toEqual(df.schema.fields[2]); + + expect(selected).toHaveLength(values.length); + let idx = 0, expected_row; + for (let row of selected) { + expected_row = values[idx++]; + expect(row.f32).toEqual(expected_row[F32]); + expect(row.dictionary).toEqual(expected_row[DICT]); + } + }); + test(`table.filter(..).count() on always false predicates returns 0`, () => { + expect(df.filter(col('i32').ge(100)).count()).toEqual(0); + expect(df.filter(col('dictionary').eq('z')).count()).toEqual(0); + }); + describe(`lit-lit comparison`, () => { + test(`always-false count() returns 0`, () => { + expect(df.filter(lit('abc').eq('def')).count()).toEqual(0); + expect(df.filter(lit(0).ge(1)).count()).toEqual(0); + }); + test(`always-true count() returns length`, () => { + expect(df.filter(lit('abc').eq('abc')).count()).toEqual(df.length); + expect(df.filter(lit(-100).le(0)).count()).toEqual(df.length); + }); + }); + describe(`col-col comparison`, () => { + test(`always-false count() returns 0`, () => { + expect(df.filter(col('dictionary').eq(col('i32'))).count()).toEqual(0); + }); + test(`always-true count() returns length`, () => { + expect(df.filter(col('dictionary').eq(col('dictionary'))).count()).toEqual(df.length); + }); + }); + }); + } +}); + +describe(`Predicate`, () => { + const p1 = col('a').gt(100); + const p2 = col('a').lt(1000); + const p3 = col('b').eq('foo'); + const p4 = col('c').eq('bar'); + const expected = [p1, p2, p3, p4]; + test(`and flattens children`, () => { + expect(and(p1, p2, p3, p4).children).toEqual(expected); + expect(and(p1.and(p2), new And(p3, p4)).children).toEqual(expected); + expect(and(p1.and(p2, p3, p4)).children).toEqual(expected); + }); + test(`or flattens children`, () => { + expect(or(p1, p2, p3, p4).children).toEqual(expected); + expect(or(p1.or(p2), new Or(p3, p4)).children).toEqual(expected); + expect(or(p1.or(p2, p3, p4)).children).toEqual(expected); + }); +}); |