1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
|
import unittest
import duckdb
import pandas as pd
from pandas.testing import assert_frame_equal
from sqlglot import exp, parse_one
from sqlglot.executor import execute
from sqlglot.executor.python import Python
from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
TPCH_SCHEMA,
load_sql_fixture_pairs,
)
# Directory holding the gzipped TPC-H table dumps; TestExecutor reads
# ``{DIR}{table}.csv.gz`` from it for every table in TPCH_SCHEMA.
DIR = f"{FIXTURES_DIR}/optimizer/tpc-h/"
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class TestExecutor(unittest.TestCase):
    """Integration tests checking sqlglot against DuckDB on the TPC-H fixtures.

    DuckDB's result for each original TPC-H query is the reference. It is
    compared against DuckDB running the optimized SQL from the fixture file and
    against sqlglot's own Python executor.
    """

    @classmethod
    def setUpClass(cls):
        """Load the TPC-H query fixtures and open a DuckDB connection with one view per table."""
        # Query results keyed by SQL text. The dict lives on the class, so every
        # test shares it and each reference query hits DuckDB at most once.
        cls.cache = {}
        # (original_sql, optimized_sql) pairs from the fixture file.
        cls.sqls = [
            (sql, expected)
            for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
        ]

        cls.conn = duckdb.connect()
        try:
            for table in TPCH_SCHEMA:
                cls.conn.execute(
                    f"""
                    CREATE VIEW {table} AS
                    SELECT *
                    FROM READ_CSV_AUTO('{DIR}{table}.csv.gz')
                    """
                )
        except BaseException:
            # unittest does not run tearDownClass when setUpClass raises, so
            # close the connection here instead of leaking it.
            cls.conn.close()
            raise

    @classmethod
    def tearDownClass(cls):
        """Close the DuckDB connection opened in setUpClass."""
        cls.conn.close()

    def cached_execute(self, sql):
        """Run ``sql`` on DuckDB and return the result as a DataFrame, memoized by SQL text.

        The same DataFrame object is returned on every call for a given ``sql``,
        so callers should not mutate it.
        """
        if sql not in self.cache:
            self.cache[sql] = self.conn.execute(sql).fetchdf()
        return self.cache[sql]

    def rename_anonymous(self, source, target):
        """Rename ``source``'s ``_col_``-style columns in place to match ``target``.

        Each column of ``source`` whose name contains ``"_col_"`` takes the name
        of the column at the same position in ``target``. All other columns
        keep their names.
        """
        # Assign positionally instead of calling ``rename`` once per column.
        # A label-based rename affects every column carrying that label, so a
        # column renamed earlier could collide with a later label and be
        # renamed a second time.
        source.columns = [
            target.columns[i] if "_col_" in column else column
            for i, column in enumerate(source.columns)
        ]

    def test_py_dialect(self):
        """Check how the Python dialect generator renders a string literal containing a quote."""
        self.assertEqual(Python().generate(parse_one("'x '''")), r"'x \''")

    def test_optimized_tpch(self):
        """Check that each of the first 20 TPC-H queries and its optimized form agree on DuckDB."""
        for i, (sql, optimized) in enumerate(self.sqls[:20], start=1):
            with self.subTest(f"{i}, {sql}"):
                a = self.cached_execute(sql)
                b = self.conn.execute(optimized).fetchdf()
                # Columns of the optimized result named like ``_col_N``
                # (presumably generated aliases) take the reference names
                # before the comparison.
                self.rename_anonymous(b, a)
                assert_frame_equal(a, b)

    def test_execute_tpch(self):
        """Check sqlglot's Python executor against DuckDB on the first three TPC-H queries.

        Table references are rewritten into ``READ_CSV`` calls over the
        pipe-delimited fixture files, so the executor loads the data itself.
        """

        def to_csv(expression):
            # Replace each table reference with a READ_CSV source aliased to the
            # table's name, presumably so qualified column references still
            # resolve.
            if isinstance(expression, exp.Table):
                return parse_one(
                    f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
                )
            return expression

        for i, (sql, _) in enumerate(self.sqls[:3], start=1):
            # One subTest per query, so one failure does not hide the others.
            with self.subTest(f"{i}, {sql}"):
                a = self.cached_execute(sql)
                csv_sql = parse_one(sql).transform(to_csv).sql(pretty=True)
                table = execute(csv_sql, TPCH_SCHEMA)
                b = pd.DataFrame(table.rows, columns=table.columns)
                assert_frame_equal(a, b, check_dtype=False)
|