summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py122
1 files changed, 94 insertions, 28 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index be6cfb9..6332cdd 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,5 +1,4 @@
import itertools
-from copy import copy
from enum import Enum, auto
from sqlglot import exp
@@ -12,7 +11,7 @@ class ScopeType(Enum):
DERIVED_TABLE = auto()
CTE = auto()
UNION = auto()
- UNNEST = auto()
+ UDTF = auto()
class Scope:
@@ -70,14 +69,11 @@ class Scope:
self._columns = None
self._external_columns = None
- def branch(self, expression, scope_type, add_sources=None, **kwargs):
+ def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
- sources = copy(self.sources)
- if add_sources:
- sources.update(add_sources)
return Scope(
expression=expression.unnest(),
- sources=sources,
+ sources={**self.cte_sources, **(chain_sources or {})},
parent=self,
scope_type=scope_type,
**kwargs,
@@ -90,30 +86,21 @@ class Scope:
self._derived_tables = []
self._raw_columns = []
- # We'll use this variable to pass state into the dfs generator.
- # Whenever we set it to True, we exclude a subtree from traversal.
- prune = False
-
- for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
- prune = False
-
+ for node, parent, _ in self.walk(bfs=False):
if node is self.expression:
continue
- if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+ elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
self._tables.append(node)
- elif isinstance(node, (exp.Unnest, exp.Lateral)):
+ elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
- prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
- prune = True
elif isinstance(node, exp.Subqueryable):
self._subqueries.append(node)
- prune = True
self._collected = True
@@ -121,6 +108,43 @@ class Scope:
if not self._collected:
self._collect()
+ def walk(self, bfs=True):
+ return walk_in_scope(self.expression, bfs=bfs)
+
+ def find(self, *expression_types, bfs=True):
+ """
+ Returns the first node in this scope which matches at least one of the specified types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Returns:
+ exp.Expression: the node which matches the criteria or None if no node matching
+ the criteria was found.
+ """
+ return next(self.find_all(*expression_types, bfs=bfs), None)
+
+ def find_all(self, *expression_types, bfs=True):
+ """
+ Returns a generator object which visits all nodes in this scope and only yields those that
+ match at least one of the specified expression types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Yields:
+ exp.Expression: nodes
+ """
+ for expression, _, _ in self.walk(bfs=bfs):
+ if isinstance(expression, expression_types):
+ yield expression
+
def replace(self, old, new):
"""
Replace `old` with `new`.
@@ -247,6 +271,16 @@ class Scope:
return self._selected_sources
@property
+ def cte_sources(self):
+ """
+ Sources that are CTEs.
+
+ Returns:
+ dict[str, Scope]: Mapping of source alias to Scope
+ """
+ return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
+
+ @property
def selects(self):
"""
Select expressions of this scope.
@@ -313,9 +347,9 @@ class Scope:
return self.scope_type == ScopeType.ROOT
@property
- def is_unnest(self):
- """Determine if this scope is an unnest"""
- return self.scope_type == ScopeType.UNNEST
+ def is_udtf(self):
+ """Determine if this scope is a UDTF (User Defined Table Function)"""
+ return self.scope_type == ScopeType.UDTF
@property
def is_correlated_subquery(self):
@@ -348,7 +382,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
- self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
+ self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@@ -399,7 +433,7 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
- elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
+ elif isinstance(scope.expression, exp.UDTF):
pass
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
@@ -410,8 +444,8 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
- yield from _traverse_subqueries(scope)
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
+ yield from _traverse_subqueries(scope)
_add_table_sources(scope)
@@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
top = None
for child_scope in _traverse_scope(
scope.branch(
- derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
- add_sources=sources if scope_type == ScopeType.CTE else None,
+ derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
+ chain_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
- scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
+ scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
)
):
yield child_scope
@@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
+
+
+def walk_in_scope(expression, bfs=True):
+ """
+ Returns a generator object which visits all nodes in the syntrax tree, stopping at
+ nodes that start child scopes.
+
+ Args:
+ expression (exp.Expression):
+ bfs (bool): if set to True the BFS traversal order will be applied,
+ otherwise the DFS traversal will be used instead.
+
+ Yields:
+ tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
+ """
+ # We'll use this variable to pass state into the dfs generator.
+ # Whenever we set it to True, we exclude a subtree from traversal.
+ prune = False
+
+ for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
+ prune = False
+
+ yield node, parent, key
+
+ if node is expression:
+ continue
+ elif isinstance(node, exp.CTE):
+ prune = True
+ elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
+ prune = True
+ elif isinstance(node, exp.Subqueryable):
+ prune = True