summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 435899a..4af5b49 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -137,8 +137,8 @@ class Scope:
if not self._collected:
self._collect()
- def walk(self, bfs=True):
- return walk_in_scope(self.expression, bfs=bfs)
+ def walk(self, bfs=True, prune=None):
+ return walk_in_scope(self.expression, bfs=bfs, prune=None)
def find(self, *expression_types, bfs=True):
return find_in_scope(self.expression, expression_types, bfs=bfs)
@@ -731,7 +731,7 @@ def _traverse_ddl(scope):
yield from _traverse_scope(query_scope)
-def walk_in_scope(expression, bfs=True):
+def walk_in_scope(expression, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in the syntrax tree, stopping at
nodes that start child scopes.
@@ -740,16 +740,20 @@ def walk_in_scope(expression, bfs=True):
expression (exp.Expression):
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
+ prune ((node, parent, arg_key) -> bool): callable that returns True if
+ the generator should stop traversing this branch of the tree.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
"""
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
- prune = False
+ crossed_scope_boundary = False
- for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
- prune = False
+ for node, parent, key in expression.walk(
+ bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
+ ):
+ crossed_scope_boundary = False
yield node, parent, key
@@ -765,7 +769,7 @@ def walk_in_scope(expression, bfs=True):
or isinstance(node, exp.UDTF)
or isinstance(node, exp.Subqueryable)
):
- prune = True
+ crossed_scope_boundary = True
if isinstance(node, (exp.Subquery, exp.UDTF)):
# The following args are not actually in the inner scope, so we should visit them