summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py125
1 files changed, 94 insertions, 31 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 102e141..8d4aecc 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -1,9 +1,11 @@
import unittest
+from functools import partial
-from sqlglot import optimizer, parse_one, table
+from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import build_scope, traverse_scope
from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
@@ -27,11 +29,17 @@ class TestOptimizer(unittest.TestCase):
}
def check_file(self, file, func, pretty=False, **kwargs):
- for meta, sql, expected in load_sql_fixture_pairs(f"optimizer/{file}.sql"):
+ for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
dialect = meta.get("dialect")
- with self.subTest(sql):
+ leave_tables_isolated = meta.get("leave_tables_isolated")
+
+ func_kwargs = {**kwargs}
+ if leave_tables_isolated is not None:
+ func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1")
+
+ with self.subTest(f"{i}, {sql}"):
self.assertEqual(
- func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect),
+ func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect),
expected,
)
@@ -123,21 +131,20 @@ class TestOptimizer(unittest.TestCase):
optimizer.optimize_joins.optimize_joins,
)
- def test_eliminate_subqueries(self):
- self.check_file(
- "eliminate_subqueries",
- optimizer.eliminate_subqueries.eliminate_subqueries,
- pretty=True,
+ def test_merge_subqueries(self):
+ optimize = partial(
+ optimizer.optimize,
+ rules=[
+ optimizer.qualify_tables.qualify_tables,
+ optimizer.qualify_columns.qualify_columns,
+ optimizer.merge_subqueries.merge_subqueries,
+ ],
)
- def test_merge_derived_tables(self):
- def optimize(expression, **kwargs):
- expression = optimizer.qualify_tables.qualify_tables(expression)
- expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
- expression = optimizer.merge_derived_tables.merge_derived_tables(expression)
- return expression
+ self.check_file("merge_subqueries", optimize, schema=self.schema)
- self.check_file("merge_derived_tables", optimize, schema=self.schema)
+ def test_eliminate_subqueries(self):
+ self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
def test_tpch(self):
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
@@ -257,17 +264,73 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
ON s.b = r.b
WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b)
"""
- scopes = traverse_scope(parse_one(sql))
- self.assertEqual(len(scopes), 5)
- self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
- self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
- self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
- self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
- self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
-
- self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
- self.assertEqual(len(scopes[4].columns), 6)
- self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"})
- self.assertEqual(scopes[4].source_columns("q"), [])
- self.assertEqual(len(scopes[4].source_columns("r")), 2)
- self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
+ for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()):
+ self.assertEqual(len(scopes), 5)
+ self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
+ self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
+ self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
+ self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
+ self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
+
+ self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
+ self.assertEqual(len(scopes[4].columns), 6)
+ self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"})
+ self.assertEqual(scopes[4].source_columns("q"), [])
+ self.assertEqual(len(scopes[4].source_columns("r")), 2)
+ self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
+
+ def test_literal_type_annotation(self):
+ tests = {
+ "SELECT 5": exp.DataType.Type.INT,
+ "SELECT 5.3": exp.DataType.Type.DOUBLE,
+ "SELECT 'bla'": exp.DataType.Type.VARCHAR,
+ "5": exp.DataType.Type.INT,
+ "5.3": exp.DataType.Type.DOUBLE,
+ "'bla'": exp.DataType.Type.VARCHAR,
+ }
+
+ for sql, target_type in tests.items():
+ expression = parse_one(sql)
+ annotated_expression = annotate_types(expression)
+
+ self.assertEqual(annotated_expression.find(exp.Literal).type, target_type)
+
+ def test_boolean_type_annotation(self):
+ tests = {
+ "SELECT TRUE": exp.DataType.Type.BOOLEAN,
+ "FALSE": exp.DataType.Type.BOOLEAN,
+ }
+
+ for sql, target_type in tests.items():
+ expression = parse_one(sql)
+ annotated_expression = annotate_types(expression)
+
+ self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type)
+
+ def test_cast_type_annotation(self):
+ expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")
+ annotate_types(expression)
+
+ self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
+ self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ)
+ self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
+
+ def test_cache_annotation(self):
+ expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
+ annotated_expression = annotate_types(expression)
+
+ self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT)
+
+ def test_binary_annotation(self):
+ expression = parse_one("SELECT 0.0 + (2 + 3)")
+ annotate_types(expression)
+
+ expression = expression.expressions[0]
+
+ self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expression.right.type, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)