diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 8d4aecc..aad84ed 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,7 +5,7 @@ from sqlglot import exp, optimizer, parse_one, table from sqlglot.errors import OptimizeError from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.schema import MappingSchema, ensure_schema -from sqlglot.optimizer.scope import build_scope, traverse_scope +from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures @@ -264,12 +264,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ON s.b = r.b WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b) """ - for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()): + expression = parse_one(sql) + for scopes in traverse_scope(expression), list(build_scope(expression).traverse()): self.assertEqual(len(scopes), 5) self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") - self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") - self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) @@ -279,6 +280,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(len(scopes[4].source_columns("r")), 2) self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"}) + self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b") + self.assertEqual({c.sql() for c in scopes[0].find_all(exp.Column)}, {"x.b"}) + + # Check that we can walk in scope from an arbitrary node + self.assertEqual( + {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)}, + {"s.b"}, + ) + def test_literal_type_annotation(self): tests = { "SELECT 5": exp.DataType.Type.INT, |