diff options
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- | sqlglot/optimizer/scope.py | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 435899a..4af5b49 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -137,8 +137,8 @@ class Scope: if not self._collected: self._collect() - def walk(self, bfs=True): - return walk_in_scope(self.expression, bfs=bfs) + def walk(self, bfs=True, prune=None): + return walk_in_scope(self.expression, bfs=bfs, prune=None) def find(self, *expression_types, bfs=True): return find_in_scope(self.expression, expression_types, bfs=bfs) @@ -731,7 +731,7 @@ def _traverse_ddl(scope): yield from _traverse_scope(query_scope) -def walk_in_scope(expression, bfs=True): +def walk_in_scope(expression, bfs=True, prune=None): """ Returns a generator object which visits all nodes in the syntrax tree, stopping at nodes that start child scopes. @@ -740,16 +740,20 @@ def walk_in_scope(expression, bfs=True): expression (exp.Expression): bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead. + prune ((node, parent, arg_key) -> bool): callable that returns True if + the generator should stop traversing this branch of the tree. Yields: tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key """ # We'll use this variable to pass state into the dfs generator. # Whenever we set it to True, we exclude a subtree from traversal. - prune = False + crossed_scope_boundary = False - for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune): - prune = False + for node, parent, key in expression.walk( + bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args)) + ): + crossed_scope_boundary = False yield node, parent, key @@ -765,7 +769,7 @@ def walk_in_scope(expression, bfs=True): or isinstance(node, exp.UDTF) or isinstance(node, exp.Subqueryable) ): - prune = True + crossed_scope_boundary = True if isinstance(node, (exp.Subquery, exp.UDTF)): # The following args are not actually in the inner scope, so we should visit them |