From 7b29f6168bf9fcb2d886447066a9bb51675e5665 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 4 Oct 2022 11:37:14 +0200 Subject: Merging upstream version 6.2.8. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/scope.py | 122 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 94 insertions(+), 28 deletions(-) (limited to 'sqlglot/optimizer/scope.py') diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index be6cfb9..6332cdd 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,5 +1,4 @@ import itertools -from copy import copy from enum import Enum, auto from sqlglot import exp @@ -12,7 +11,7 @@ class ScopeType(Enum): DERIVED_TABLE = auto() CTE = auto() UNION = auto() - UNNEST = auto() + UDTF = auto() class Scope: @@ -70,14 +69,11 @@ class Scope: self._columns = None self._external_columns = None - def branch(self, expression, scope_type, add_sources=None, **kwargs): + def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" - sources = copy(self.sources) - if add_sources: - sources.update(add_sources) return Scope( expression=expression.unnest(), - sources=sources, + sources={**self.cte_sources, **(chain_sources or {})}, parent=self, scope_type=scope_type, **kwargs, @@ -90,30 +86,21 @@ class Scope: self._derived_tables = [] self._raw_columns = [] - # We'll use this variable to pass state into the dfs generator. - # Whenever we set it to True, we exclude a subtree from traversal. 
- prune = False - - for node, parent, _ in self.expression.dfs(prune=lambda *_: prune): - prune = False - + for node, parent, _ in self.walk(bfs=False): if node is self.expression: continue - if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): + elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): self._raw_columns.append(node) elif isinstance(node, exp.Table): self._tables.append(node) - elif isinstance(node, (exp.Unnest, exp.Lateral)): + elif isinstance(node, exp.UDTF): self._derived_tables.append(node) elif isinstance(node, exp.CTE): self._ctes.append(node) - prune = True elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): self._derived_tables.append(node) - prune = True elif isinstance(node, exp.Subqueryable): self._subqueries.append(node) - prune = True self._collected = True @@ -121,6 +108,43 @@ class Scope: if not self._collected: self._collect() + def walk(self, bfs=True): + return walk_in_scope(self.expression, bfs=bfs) + + def find(self, *expression_types, bfs=True): + """ + Returns the first node in this scope which matches at least one of the specified types. + + This does NOT traverse into subscopes. + + Args: + expression_types (type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. + + Returns: + exp.Expression: the node which matches the criteria or None if no node matching + the criteria was found. + """ + return next(self.find_all(*expression_types, bfs=bfs), None) + + def find_all(self, *expression_types, bfs=True): + """ + Returns a generator object which visits all nodes in this scope and only yields those that + match at least one of the specified expression types. + + This does NOT traverse into subscopes. + + Args: + expression_types (type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. 
+ + Yields: + exp.Expression: nodes + """ + for expression, _, _ in self.walk(bfs=bfs): + if isinstance(expression, expression_types): + yield expression + def replace(self, old, new): """ Replace `old` with `new`. @@ -246,6 +270,16 @@ class Scope: self._selected_sources = result return self._selected_sources + @property + def cte_sources(self): + """ + Sources that are CTEs. + + Returns: + dict[str, Scope]: Mapping of source alias to Scope + """ + return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte} + @property def selects(self): """ @@ -313,9 +347,9 @@ class Scope: return self.scope_type == ScopeType.ROOT @property - def is_unnest(self): - """Determine if this scope is an unnest""" - return self.scope_type == ScopeType.UNNEST + def is_udtf(self): + """Determine if this scope is a UDTF (User Defined Table Function)""" + return self.scope_type == ScopeType.UDTF @property def is_correlated_subquery(self): @@ -348,7 +382,7 @@ class Scope: Scope: scope instances in depth-first-search post-order """ for child_scope in itertools.chain( - self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes + self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes ): yield from child_scope.traverse() yield self @@ -399,7 +433,7 @@ def _traverse_scope(scope): yield from _traverse_select(scope) elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) - elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)): + elif isinstance(scope.expression, exp.UDTF): pass elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) @@ -410,8 +444,8 @@ def _traverse_scope(scope): def _traverse_select(scope): yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) - yield from _traverse_subqueries(scope) yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) + yield from 
_traverse_subqueries(scope) _add_table_sources(scope) @@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): top = None for child_scope in _traverse_scope( scope.branch( - derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, - add_sources=sources if scope_type == ScopeType.CTE else None, + derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, + chain_sources=sources if scope_type == ScopeType.CTE else None, outer_column_list=derived_table.alias_column_names, - scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type, + scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type, ) ): yield child_scope @@ -483,3 +517,35 @@ def _traverse_subqueries(scope): yield child_scope top = child_scope scope.subquery_scopes.append(top) + + +def walk_in_scope(expression, bfs=True): + """ + Returns a generator object which visits all nodes in the syntax tree, stopping at + nodes that start child scopes. + + Args: + expression (exp.Expression): + bfs (bool): if set to True the BFS traversal order will be applied, + otherwise the DFS traversal will be used instead. + + Yields: + tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key + """ + # We'll use this variable to pass state into the dfs generator. + # Whenever we set it to True, we exclude a subtree from traversal. + prune = False + + for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune): + prune = False + + yield node, parent, key + + if node is expression: + continue + elif isinstance(node, exp.CTE): + prune = True + elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): + prune = True + elif isinstance(node, exp.Subqueryable): + prune = True -- cgit v1.2.3