summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 8d4aecc..aad84ed 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -5,7 +5,7 @@ from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
-from sqlglot.optimizer.scope import build_scope, traverse_scope
+from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
@@ -264,12 +264,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
ON s.b = r.b
WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b)
"""
- for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()):
+ expression = parse_one(sql)
+ for scopes in traverse_scope(expression), list(build_scope(expression).traverse()):
self.assertEqual(len(scopes), 5)
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
- self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
- self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
+ self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y")
+ self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
@@ -279,6 +280,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(len(scopes[4].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
+ self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
+ self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
+ self.assertEqual({c.sql() for c in scopes[0].find_all(exp.Column)}, {"x.b"})
+
+ # Check that we can walk in scope from an arbitrary node
+ self.assertEqual(
+ {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
+ {"s.b"},
+ )
+
def test_literal_type_annotation(self):
tests = {
"SELECT 5": exp.DataType.Type.INT,