diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 125 |
1 files changed, 94 insertions, 31 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 102e141..8d4aecc 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1,9 +1,11 @@ import unittest +from functools import partial -from sqlglot import optimizer, parse_one, table +from sqlglot import exp, optimizer, parse_one, table from sqlglot.errors import OptimizeError +from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.schema import MappingSchema, ensure_schema -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import build_scope, traverse_scope from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures @@ -27,11 +29,17 @@ class TestOptimizer(unittest.TestCase): } def check_file(self, file, func, pretty=False, **kwargs): - for meta, sql, expected in load_sql_fixture_pairs(f"optimizer/{file}.sql"): + for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1): dialect = meta.get("dialect") - with self.subTest(sql): + leave_tables_isolated = meta.get("leave_tables_isolated") + + func_kwargs = {**kwargs} + if leave_tables_isolated is not None: + func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1") + + with self.subTest(f"{i}, {sql}"): self.assertEqual( - func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect), + func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect), expected, ) @@ -123,21 +131,20 @@ class TestOptimizer(unittest.TestCase): optimizer.optimize_joins.optimize_joins, ) - def test_eliminate_subqueries(self): - self.check_file( - "eliminate_subqueries", - optimizer.eliminate_subqueries.eliminate_subqueries, - pretty=True, + def test_merge_subqueries(self): + optimize = partial( + optimizer.optimize, + rules=[ + optimizer.qualify_tables.qualify_tables, + optimizer.qualify_columns.qualify_columns, + optimizer.merge_subqueries.merge_subqueries, + ], ) - def test_merge_derived_tables(self): - def optimize(expression, **kwargs): - expression = optimizer.qualify_tables.qualify_tables(expression) - expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) - expression = optimizer.merge_derived_tables.merge_derived_tables(expression) - return expression + self.check_file("merge_subqueries", optimize, schema=self.schema) - self.check_file("merge_derived_tables", optimize, schema=self.schema) + def test_eliminate_subqueries(self): + self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries) def test_tpch(self): self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) @@ -257,17 +264,73 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ON s.b = r.b WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b) """ - scopes = traverse_scope(parse_one(sql)) - self.assertEqual(len(scopes), 5) - self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") - self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") - self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") - self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") - self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) - - self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) - self.assertEqual(len(scopes[4].columns), 6) - self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) - self.assertEqual(scopes[4].source_columns("q"), []) - self.assertEqual(len(scopes[4].source_columns("r")), 2) - self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()): + self.assertEqual(len(scopes), 5) + self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") + self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") + self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") + self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) + + self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) + self.assertEqual(len(scopes[4].columns), 6) + self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) + self.assertEqual(scopes[4].source_columns("q"), []) + self.assertEqual(len(scopes[4].source_columns("r")), 2) + self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + + def test_literal_type_annotation(self): + tests = { + "SELECT 5": exp.DataType.Type.INT, + "SELECT 5.3": exp.DataType.Type.DOUBLE, + "SELECT 'bla'": exp.DataType.Type.VARCHAR, + "5": exp.DataType.Type.INT, + "5.3": exp.DataType.Type.DOUBLE, + "'bla'": exp.DataType.Type.VARCHAR, + } + + for sql, target_type in tests.items(): + expression = parse_one(sql) + annotated_expression = annotate_types(expression) + + self.assertEqual(annotated_expression.find(exp.Literal).type, target_type) + + def test_boolean_type_annotation(self): + tests = { + "SELECT TRUE": exp.DataType.Type.BOOLEAN, + "FALSE": exp.DataType.Type.BOOLEAN, + } + + for sql, target_type in tests.items(): + expression = parse_one(sql) + annotated_expression = annotate_types(expression) + + self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type) + + def test_cast_type_annotation(self): + expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))") + annotate_types(expression) + + self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR) + self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) + + def test_cache_annotation(self): + expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") + annotated_expression = annotate_types(expression) + + self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT) + + def test_binary_annotation(self): + expression = parse_one("SELECT 0.0 + (2 + 3)") + annotate_types(expression) + + expression = expression.expressions[0] + + self.assertEqual(expression.type, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.right.type, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.type, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT) |