diff options
Diffstat (limited to 'tests/test_executor.py')
-rw-r--r-- | tests/test_executor.py | 81 |
1 files changed, 56 insertions, 25 deletions
diff --git a/tests/test_executor.py b/tests/test_executor.py index 9a2b46b..981c1d4 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,3 +1,4 @@ +import os import datetime import unittest from datetime import date @@ -17,40 +18,53 @@ from tests.helpers import ( FIXTURES_DIR, SKIP_INTEGRATION, TPCH_SCHEMA, + TPCDS_SCHEMA, load_sql_fixture_pairs, + string_to_bool, ) -DIR = FIXTURES_DIR + "/optimizer/tpc-h/" +DIR_TPCH = FIXTURES_DIR + "/optimizer/tpc-h/" +DIR_TPCDS = FIXTURES_DIR + "/optimizer/tpc-ds/" @unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") class TestExecutor(unittest.TestCase): @classmethod def setUpClass(cls): - cls.conn = duckdb.connect() + cls.tpch_conn = duckdb.connect() + cls.tpcds_conn = duckdb.connect() for table, columns in TPCH_SCHEMA.items(): - cls.conn.execute( + cls.tpch_conn.execute( f""" CREATE VIEW {table} AS SELECT * - FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns}) + FROM READ_CSV('{DIR_TPCH}{table}.csv.gz', delim='|', header=True, columns={columns}) + """ + ) + + for table, columns in TPCDS_SCHEMA.items(): + cls.tpcds_conn.execute( + f""" + CREATE VIEW {table} AS + SELECT * + FROM READ_CSV('{DIR_TPCDS}{table}.csv.gz', delim='|', header=True, columns={columns}) """ ) cls.cache = {} - cls.sqls = [ - (sql, expected) - for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql") - ] + cls.tpch_sqls = list(load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")) + cls.tpcds_sqls = list(load_sql_fixture_pairs("optimizer/tpc-ds/tpc-ds.sql")) @classmethod def tearDownClass(cls): - cls.conn.close() + cls.tpch_conn.close() + cls.tpcds_conn.close() - def cached_execute(self, sql): + def cached_execute(self, sql, tpch=True): + conn = self.tpch_conn if tpch else self.tpcds_conn if sql not in self.cache: - self.cache[sql] = self.conn.execute(transpile(sql, write="duckdb")[0]).fetchdf() + self.cache[sql] = conn.execute(transpile(sql, write="duckdb")[0]).fetchdf() return self.cache[sql] def rename_anonymous(self, source, target): @@ -66,18 +80,28 @@ class TestExecutor(unittest.TestCase): self.assertEqual(generate(parse_one("x is null")), "scope[None][x] is None") def test_optimized_tpch(self): - for i, (sql, optimized) in enumerate(self.sqls, start=1): + for i, (_, sql, optimized) in enumerate(self.tpch_sqls, start=1): with self.subTest(f"{i}, {sql}"): - a = self.cached_execute(sql) - b = self.conn.execute(transpile(optimized, write="duckdb")[0]).fetchdf() + a = self.cached_execute(sql, tpch=True) + b = self.tpch_conn.execute(transpile(optimized, write="duckdb")[0]).fetchdf() self.rename_anonymous(b, a) assert_frame_equal(a, b) + def subtestHelper(self, i, table, tpch=True): + with self.subTest(f"{'tpc-h' if tpch else 'tpc-ds'} {i + 1}"): + _, sql, _ = self.tpch_sqls[i] if tpch else self.tpcds_sqls[i] + a = self.cached_execute(sql, tpch=tpch) + b = pd.DataFrame( + ((np.nan if c is None else c for c in r) for r in table.rows), + columns=table.columns, + ) + assert_frame_equal(a, b, check_dtype=False, check_index_type=False) + def test_execute_tpch(self): def to_csv(expression): if isinstance(expression, exp.Table) and expression.name not in ("revenue"): return parse_one( - f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" + f"READ_CSV('{DIR_TPCH}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" ) return expression @@ -87,19 +111,26 @@ class TestExecutor(unittest.TestCase): execute, ( (parse_one(sql).transform(to_csv).sql(pretty=True), TPCH_SCHEMA) - for sql, _ in self.sqls + for _, sql, _ in self.tpch_sqls ), ) ): - with self.subTest(f"tpch-h {i + 1}"): - sql, _ = self.sqls[i] - a = self.cached_execute(sql) - b = pd.DataFrame( - ((np.nan if c is None else c for c in r) for r in table.rows), - columns=table.columns, - ) - - assert_frame_equal(a, b, check_dtype=False, check_index_type=False) + self.subtestHelper(i, table, tpch=True) + + def test_execute_tpcds(self): + def to_csv(expression): + if isinstance(expression, exp.Table) and os.path.exists( + f"{DIR_TPCDS}{expression.name}.csv.gz" + ): + return parse_one( + f"READ_CSV('{DIR_TPCDS}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" + ) + return expression + + for i, (meta, sql, _) in enumerate(self.tpcds_sqls): + if string_to_bool(meta.get("execute")): + table = execute(parse_one(sql).transform(to_csv).sql(pretty=True), TPCDS_SCHEMA) + self.subtestHelper(i, table, tpch=False) def test_execute_callable(self): tables = { |