diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
commit | f73e9af131151f1e058446361c35b05c4c90bf10 (patch) | |
tree | ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/optimizer/scope.py | |
parent | Releasing debian version 17.12.0-1. (diff) | |
download | sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip |
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- | sqlglot/optimizer/scope.py | 72 |
1 files changed, 41 insertions, 31 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index fb12384..435899a 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -6,7 +6,7 @@ from enum import Enum, auto from sqlglot import exp from sqlglot.errors import OptimizeError -from sqlglot.helper import find_new_name +from sqlglot.helper import ensure_collection, find_new_name logger = logging.getLogger("sqlglot") @@ -141,38 +141,10 @@ class Scope: return walk_in_scope(self.expression, bfs=bfs) def find(self, *expression_types, bfs=True): - """ - Returns the first node in this scope which matches at least one of the specified types. - - This does NOT traverse into subscopes. - - Args: - expression_types (type): the expression type(s) to match. - bfs (bool): True to use breadth-first search, False to use depth-first. - - Returns: - exp.Expression: the node which matches the criteria or None if no node matching - the criteria was found. - """ - return next(self.find_all(*expression_types, bfs=bfs), None) + return find_in_scope(self.expression, expression_types, bfs=bfs) def find_all(self, *expression_types, bfs=True): - """ - Returns a generator object which visits all nodes in this scope and only yields those that - match at least one of the specified expression types. - - This does NOT traverse into subscopes. - - Args: - expression_types (type): the expression type(s) to match. - bfs (bool): True to use breadth-first search, False to use depth-first. - - Yields: - exp.Expression: nodes - """ - for expression, *_ in self.walk(bfs=bfs): - if isinstance(expression, expression_types): - yield expression + return find_all_in_scope(self.expression, expression_types, bfs=bfs) def replace(self, old, new): """ @@ -800,3 +772,41 @@ def walk_in_scope(expression, bfs=True): for key in ("joins", "laterals", "pivots"): for arg in node.args.get(key) or []: yield from walk_in_scope(arg, bfs=bfs) + + +def find_all_in_scope(expression, expression_types, bfs=True): + """ + Returns a generator object which visits all nodes in this scope and only yields those that + match at least one of the specified expression types. + + This does NOT traverse into subscopes. + + Args: + expression (exp.Expression): + expression_types (tuple[type]|type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. + + Yields: + exp.Expression: nodes + """ + for expression, *_ in walk_in_scope(expression, bfs=bfs): + if isinstance(expression, tuple(ensure_collection(expression_types))): + yield expression + + +def find_in_scope(expression, expression_types, bfs=True): + """ + Returns the first node in this scope which matches at least one of the specified types. + + This does NOT traverse into subscopes. + + Args: + expression (exp.Expression): + expression_types (tuple[type]|type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. + + Returns: + exp.Expression: the node which matches the criteria or None if no node matching + the criteria was found. + """ + return next(find_all_in_scope(expression, expression_types, bfs=bfs), None) |