summaryrefslogtreecommitdiffstats
path: root/tests/test_executor.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_executor.py')
-rw-r--r--tests/test_executor.py81
1 files changed, 56 insertions, 25 deletions
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 9a2b46b..981c1d4 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -1,3 +1,4 @@
+import os
import datetime
import unittest
from datetime import date
@@ -17,40 +18,53 @@ from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
TPCH_SCHEMA,
+ TPCDS_SCHEMA,
load_sql_fixture_pairs,
+ string_to_bool,
)
-DIR = FIXTURES_DIR + "/optimizer/tpc-h/"
+DIR_TPCH = FIXTURES_DIR + "/optimizer/tpc-h/"
+DIR_TPCDS = FIXTURES_DIR + "/optimizer/tpc-ds/"
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class TestExecutor(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.conn = duckdb.connect()
+ cls.tpch_conn = duckdb.connect()
+ cls.tpcds_conn = duckdb.connect()
for table, columns in TPCH_SCHEMA.items():
- cls.conn.execute(
+ cls.tpch_conn.execute(
f"""
CREATE VIEW {table} AS
SELECT *
- FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns})
+ FROM READ_CSV('{DIR_TPCH}{table}.csv.gz', delim='|', header=True, columns={columns})
+ """
+ )
+
+ for table, columns in TPCDS_SCHEMA.items():
+ cls.tpcds_conn.execute(
+ f"""
+ CREATE VIEW {table} AS
+ SELECT *
+ FROM READ_CSV('{DIR_TPCDS}{table}.csv.gz', delim='|', header=True, columns={columns})
"""
)
cls.cache = {}
- cls.sqls = [
- (sql, expected)
- for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
- ]
+ cls.tpch_sqls = list(load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql"))
+ cls.tpcds_sqls = list(load_sql_fixture_pairs("optimizer/tpc-ds/tpc-ds.sql"))
@classmethod
def tearDownClass(cls):
- cls.conn.close()
+ cls.tpch_conn.close()
+ cls.tpcds_conn.close()
- def cached_execute(self, sql):
+ def cached_execute(self, sql, tpch=True):
+ conn = self.tpch_conn if tpch else self.tpcds_conn
if sql not in self.cache:
- self.cache[sql] = self.conn.execute(transpile(sql, write="duckdb")[0]).fetchdf()
+ self.cache[sql] = conn.execute(transpile(sql, write="duckdb")[0]).fetchdf()
return self.cache[sql]
def rename_anonymous(self, source, target):
@@ -66,18 +80,28 @@ class TestExecutor(unittest.TestCase):
self.assertEqual(generate(parse_one("x is null")), "scope[None][x] is None")
def test_optimized_tpch(self):
- for i, (sql, optimized) in enumerate(self.sqls, start=1):
+ for i, (_, sql, optimized) in enumerate(self.tpch_sqls, start=1):
with self.subTest(f"{i}, {sql}"):
- a = self.cached_execute(sql)
- b = self.conn.execute(transpile(optimized, write="duckdb")[0]).fetchdf()
+ a = self.cached_execute(sql, tpch=True)
+ b = self.tpch_conn.execute(transpile(optimized, write="duckdb")[0]).fetchdf()
self.rename_anonymous(b, a)
assert_frame_equal(a, b)
+ def subtestHelper(self, i, table, tpch=True):
+ with self.subTest(f"{'tpc-h' if tpch else 'tpc-ds'} {i + 1}"):
+ _, sql, _ = self.tpch_sqls[i] if tpch else self.tpcds_sqls[i]
+ a = self.cached_execute(sql, tpch=tpch)
+ b = pd.DataFrame(
+ ((np.nan if c is None else c for c in r) for r in table.rows),
+ columns=table.columns,
+ )
+ assert_frame_equal(a, b, check_dtype=False, check_index_type=False)
+
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table) and expression.name not in ("revenue"):
return parse_one(
- f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
+ f"READ_CSV('{DIR_TPCH}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
@@ -87,19 +111,26 @@ class TestExecutor(unittest.TestCase):
execute,
(
(parse_one(sql).transform(to_csv).sql(pretty=True), TPCH_SCHEMA)
- for sql, _ in self.sqls
+ for _, sql, _ in self.tpch_sqls
),
)
):
- with self.subTest(f"tpch-h {i + 1}"):
- sql, _ = self.sqls[i]
- a = self.cached_execute(sql)
- b = pd.DataFrame(
- ((np.nan if c is None else c for c in r) for r in table.rows),
- columns=table.columns,
- )
-
- assert_frame_equal(a, b, check_dtype=False, check_index_type=False)
+ self.subtestHelper(i, table, tpch=True)
+
+ def test_execute_tpcds(self):
+ def to_csv(expression):
+ if isinstance(expression, exp.Table) and os.path.exists(
+ f"{DIR_TPCDS}{expression.name}.csv.gz"
+ ):
+ return parse_one(
+ f"READ_CSV('{DIR_TPCDS}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
+ )
+ return expression
+
+ for i, (meta, sql, _) in enumerate(self.tpcds_sqls):
+ if string_to_bool(meta.get("execute")):
+ table = execute(parse_one(sql).transform(to_csv).sql(pretty=True), TPCDS_SCHEMA)
+ self.subtestHelper(i, table, tpch=False)
def test_execute_callable(self):
tables = {