summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-09-07 11:39:48 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-09-07 11:39:48 +0000
commitf73e9af131151f1e058446361c35b05c4c90bf10 (patch)
treeed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/optimizer/scope.py
parentReleasing debian version 17.12.0-1. (diff)
downloadsqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz
sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py72
1 files changed, 41 insertions, 31 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index fb12384..435899a 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -6,7 +6,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
-from sqlglot.helper import find_new_name
+from sqlglot.helper import ensure_collection, find_new_name
logger = logging.getLogger("sqlglot")
@@ -141,38 +141,10 @@ class Scope:
return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
- """
- Returns the first node in this scope which matches at least one of the specified types.
-
- This does NOT traverse into subscopes.
-
- Args:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
-
- Returns:
- exp.Expression: the node which matches the criteria or None if no node matching
- the criteria was found.
- """
- return next(self.find_all(*expression_types, bfs=bfs), None)
+ return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
- """
- Returns a generator object which visits all nodes in this scope and only yields those that
- match at least one of the specified expression types.
-
- This does NOT traverse into subscopes.
-
- Args:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
-
- Yields:
- exp.Expression: nodes
- """
- for expression, *_ in self.walk(bfs=bfs):
- if isinstance(expression, expression_types):
- yield expression
+ return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
"""
@@ -800,3 +772,41 @@ def walk_in_scope(expression, bfs=True):
for key in ("joins", "laterals", "pivots"):
for arg in node.args.get(key) or []:
yield from walk_in_scope(arg, bfs=bfs)
+
+
+def find_all_in_scope(expression, expression_types, bfs=True):
+ """
+ Returns a generator object which visits all nodes in this scope and only yields those that
+ match at least one of the specified expression types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression (exp.Expression):
+ expression_types (tuple[type]|type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Yields:
+ exp.Expression: nodes
+ """
+ for expression, *_ in walk_in_scope(expression, bfs=bfs):
+ if isinstance(expression, tuple(ensure_collection(expression_types))):
+ yield expression
+
+
+def find_in_scope(expression, expression_types, bfs=True):
+ """
+ Returns the first node in this scope which matches at least one of the specified types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression (exp.Expression):
+ expression_types (tuple[type]|type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Returns:
+ exp.Expression: the node which matches the criteria or None if no node matching
+ the criteria was found.
+ """
+ return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)