diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 119 |
1 files changed, 23 insertions, 96 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a67e9db..3b5990f 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,11 +5,11 @@ import duckdb from pandas.testing import assert_frame_equal import sqlglot -from sqlglot import exp, optimizer, parse_one, table +from sqlglot import exp, optimizer, parse_one from sqlglot.errors import OptimizeError from sqlglot.optimizer.annotate_types import annotate_types -from sqlglot.optimizer.schema import MappingSchema, ensure_schema from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope +from sqlglot.schema import MappingSchema from tests.helpers import ( TPCH_SCHEMA, load_sql_fixture_pairs, @@ -29,19 +29,19 @@ class TestOptimizer(unittest.TestCase): CREATE TABLE x (a INT, b INT); CREATE TABLE y (b INT, c INT); CREATE TABLE z (b INT, c INT); - + INSERT INTO x VALUES (1, 1); INSERT INTO x VALUES (2, 2); INSERT INTO x VALUES (2, 2); INSERT INTO x VALUES (3, 3); INSERT INTO x VALUES (null, null); - + INSERT INTO y VALUES (2, 2); INSERT INTO y VALUES (2, 2); INSERT INTO y VALUES (3, 3); INSERT INTO y VALUES (4, 4); INSERT INTO y VALUES (null, null); - + INSERT INTO y VALUES (3, 3); INSERT INTO y VALUES (3, 3); INSERT INTO y VALUES (4, 4); @@ -80,8 +80,8 @@ class TestOptimizer(unittest.TestCase): with self.subTest(title): self.assertEqual( - optimized.sql(pretty=pretty, dialect=dialect), expected, + optimized.sql(pretty=pretty, dialect=dialect), ) should_execute = meta.get("execute") @@ -223,85 +223,6 @@ class TestOptimizer(unittest.TestCase): def test_tpch(self): self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) - def test_schema(self): - schema = ensure_schema( - { - "x": { - "a": "uint64", - } - } - ) - self.assertEqual( - schema.column_names( - table( - "x", - ) - ), - ["a"], - ) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x2")) - - schema = ensure_schema( - { - "db": { - "x": { - "a": "uint64", - } - } - } - ) - self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - - schema = ensure_schema( - { - "c": { - "db": { - "x": { - "a": "uint64", - } - } - } - } - ) - self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c2")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - - schema = ensure_schema( - MappingSchema( - { - "x": { - "a": "uint64", - } - } - ) - ) - self.assertEqual(schema.column_names(table("x")), ["a"]) - - with self.assertRaises(OptimizeError): - ensure_schema({}) - def test_file_schema(self): expression = parse_one( """ @@ -327,6 +248,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') SELECT x.b FROM x ), r AS ( SELECT y.b FROM y + ), z as ( + SELECT cola, colb FROM (VALUES(1, 'test')) AS tab(cola, colb) ) SELECT r.b, @@ -340,19 +263,23 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') """ expression = parse_one(sql) for scopes in traverse_scope(expression), list(build_scope(expression).traverse()): - self.assertEqual(len(scopes), 5) + self.assertEqual(len(scopes), 7) self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") - self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y") - self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") - self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) - - self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) - self.assertEqual(len(scopes[4].columns), 6) - self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) - self.assertEqual(scopes[4].source_columns("q"), []) - self.assertEqual(len(scopes[4].source_columns("r")), 2) - self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)") + self.assertEqual( + scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)" + ) + self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") + self.assertEqual(scopes[6].expression.sql(), parse_one(sql).sql()) + + self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"}) + self.assertEqual(len(scopes[6].columns), 6) + self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"}) + self.assertEqual(scopes[6].source_columns("q"), []) + self.assertEqual(len(scopes[6].source_columns("r")), 2) + self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"}) self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"}) self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b") |