summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py33
1 files changed, 15 insertions, 18 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 40540b3..102e141 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -31,9 +31,7 @@ class TestOptimizer(unittest.TestCase):
dialect = meta.get("dialect")
with self.subTest(sql):
self.assertEqual(
- func(parse_one(sql, read=dialect), **kwargs).sql(
- pretty=pretty, dialect=dialect
- ),
+ func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect),
expected,
)
@@ -86,9 +84,7 @@ class TestOptimizer(unittest.TestCase):
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
with self.subTest(sql):
with self.assertRaises(OptimizeError):
- optimizer.qualify_columns.qualify_columns(
- parse_one(sql), schema=self.schema
- )
+ optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema)
def test_quote_identities(self):
self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
@@ -100,9 +96,7 @@ class TestOptimizer(unittest.TestCase):
expression = optimizer.pushdown_projections.pushdown_projections(expression)
return expression
- self.check_file(
- "pushdown_projections", pushdown_projections, schema=self.schema
- )
+ self.check_file("pushdown_projections", pushdown_projections, schema=self.schema)
def test_simplify(self):
self.check_file("simplify", optimizer.simplify.simplify)
@@ -115,9 +109,7 @@ class TestOptimizer(unittest.TestCase):
)
def test_pushdown_predicates(self):
- self.check_file(
- "pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates
- )
+ self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates)
def test_expand_multi_table_selects(self):
self.check_file(
@@ -138,10 +130,17 @@ class TestOptimizer(unittest.TestCase):
pretty=True,
)
+ def test_merge_derived_tables(self):
+ def optimize(expression, **kwargs):
+ expression = optimizer.qualify_tables.qualify_tables(expression)
+ expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ expression = optimizer.merge_derived_tables.merge_derived_tables(expression)
+ return expression
+
+ self.check_file("merge_derived_tables", optimize, schema=self.schema)
+
def test_tpch(self):
- self.check_file(
- "tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True
- )
+ self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
def test_schema(self):
schema = ensure_schema(
@@ -262,9 +261,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(len(scopes), 5)
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
- self.assertEqual(
- scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b"
- )
+ self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())