diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:11:53 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:12:02 +0000 |
commit | 8d36f5966675e23bee7026ba37ae0647fbf47300 (patch) | |
tree | df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot/optimizer/scope.py | |
parent | Releasing debian version 22.2.0-1. (diff) | |
download | sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip |
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- | sqlglot/optimizer/scope.py | 165 |
1 files changed, 91 insertions, 74 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 443fa6c..073ced2 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -8,7 +8,7 @@ from enum import Enum, auto from sqlglot import exp from sqlglot.errors import OptimizeError -from sqlglot.helper import ensure_collection, find_new_name +from sqlglot.helper import ensure_collection, find_new_name, seq_get logger = logging.getLogger("sqlglot") @@ -38,11 +38,11 @@ class Scope: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source. cte_sources (dict[str, Scope]): Sources from CTES - outer_column_list (list[str]): If this is a derived table or CTE, and the outer query - defines a column list of it's alias of this scope, this is that list of columns. + outer_columns (list[str]): If this is a derived table or CTE, and the outer query + defines a column list for the alias of this scope, this is that list of columns. For example: SELECT * FROM (SELECT ...) AS y(col1, col2) - The inner query would have `["col1", "col2"]` for its `outer_column_list` + The inner query would have `["col1", "col2"]` for its `outer_columns` parent (Scope): Parent scope scope_type (ScopeType): Type of this scope, relative to it's parent subquery_scopes (list[Scope]): List of all child scopes for subqueries @@ -58,7 +58,7 @@ class Scope: self, expression, sources=None, - outer_column_list=None, + outer_columns=None, parent=None, scope_type=ScopeType.ROOT, lateral_sources=None, @@ -70,7 +70,7 @@ class Scope: self.cte_sources = cte_sources or {} self.sources.update(self.lateral_sources) self.sources.update(self.cte_sources) - self.outer_column_list = outer_column_list or [] + self.outer_columns = outer_columns or [] self.parent = parent self.scope_type = scope_type self.subquery_scopes = [] @@ -119,10 +119,11 @@ class Scope: self._raw_columns = [] self._join_hints = [] - for node, parent, _ in self.walk(bfs=False): + for node in self.walk(bfs=False): if node is self.expression: continue - elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): + + if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): self._raw_columns.append(node) elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): self._tables.append(node) @@ -132,10 +133,8 @@ class Scope: self._udtfs.append(node) elif isinstance(node, exp.CTE): self._ctes.append(node) - elif ( - isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) - and _is_derived_table(node) + elif _is_derived_table(node) and isinstance( + node.parent, (exp.From, exp.Join, exp.Subquery) ): self._derived_tables.append(node) elif isinstance(node, exp.UNWRAPPED_QUERIES): @@ -438,11 +437,21 @@ class Scope: Yields: Scope: scope instances in depth-first-search post-order """ - for child_scope in itertools.chain( - self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes - ): - yield from child_scope.traverse() - yield self + stack = [self] + result = [] + while stack: + scope = stack.pop() + result.append(scope) + stack.extend( + itertools.chain( + scope.cte_scopes, + scope.union_scopes, + scope.table_scopes, + scope.subquery_scopes, + ) + ) + + yield from reversed(result) def ref_count(self): """ @@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) Args: - expression (exp.Expression): expression to traverse + expression: Expression to traverse Returns: - list[Scope]: scope instances + A list of the created scope instances """ - if isinstance(expression, exp.Query) or ( - isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query) - ): + if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query): + # We ignore the DDL expression and build a scope for its query instead + ddl_with = expression.args.get("with") + expression = expression.expression + + # If the DDL has CTEs attached, we need to add them to the query, or + # prepend them if the query itself already has CTEs attached to it + if ddl_with: + ddl_with.pop() + query_ctes = expression.ctes + if not query_ctes: + expression.set("with", ddl_with) + else: + expression.args["with"].set("recursive", ddl_with.recursive) + expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes]) + + if isinstance(expression, exp.Query): return list(_traverse_scope(Scope(expression))) return [] @@ -499,21 +522,21 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]: Build a scope tree. Args: - expression (exp.Expression): expression to build the scope tree for + expression: Expression to build the scope tree for. + Returns: - Scope: root scope + The root scope """ - scopes = traverse_scope(expression) - if scopes: - return scopes[-1] - return None + return seq_get(traverse_scope(expression), -1) def _traverse_scope(scope): if isinstance(scope.expression, exp.Select): yield from _traverse_select(scope) elif isinstance(scope.expression, exp.Union): + yield from _traverse_ctes(scope) yield from _traverse_union(scope) + return elif isinstance(scope.expression, exp.Subquery): if scope.is_root: yield from _traverse_select(scope) @@ -523,8 +546,6 @@ def _traverse_scope(scope): yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): yield from _traverse_udtfs(scope) - elif isinstance(scope.expression, exp.DDL): - yield from _traverse_ddl(scope) else: logger.warning( "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) @@ -541,30 +562,38 @@ def _traverse_select(scope): def _traverse_union(scope): - yield from _traverse_ctes(scope) + prev_scope = None + union_scope_stack = [scope] + expression_stack = [scope.expression.right, scope.expression.left] - # The last scope to be yield should be the top most scope - left = None - for left in _traverse_scope( - scope.branch( - scope.expression.left, - outer_column_list=scope.outer_column_list, - scope_type=ScopeType.UNION, - ) - ): - yield left + while expression_stack: + expression = expression_stack.pop() + union_scope = union_scope_stack[-1] - right = None - for right in _traverse_scope( - scope.branch( - scope.expression.right, - outer_column_list=scope.outer_column_list, + new_scope = union_scope.branch( + expression, + outer_columns=union_scope.outer_columns, scope_type=ScopeType.UNION, ) - ): - yield right - scope.union_scopes = [left, right] + if isinstance(expression, exp.Union): + yield from _traverse_ctes(new_scope) + + union_scope_stack.append(new_scope) + expression_stack.extend([expression.right, expression.left]) + continue + + for scope in _traverse_scope(new_scope): + yield scope + + if prev_scope: + union_scope_stack.pop() + union_scope.union_scopes = [prev_scope, scope] + prev_scope = union_scope + + yield union_scope + else: + prev_scope = scope def _traverse_ctes(scope): @@ -588,7 +617,7 @@ def _traverse_ctes(scope): scope.branch( cte.this, cte_sources=sources, - outer_column_list=cte.alias_column_names, + outer_columns=cte.alias_column_names, scope_type=ScopeType.CTE, ) ): @@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool: as it doesn't introduce a new scope. If an alias is present, it shadows all names under the Subquery, so that's one exception to this rule. """ - return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)) + return isinstance(expression, exp.Subquery) and bool( + expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) + ) def _traverse_tables(scope): @@ -681,7 +712,7 @@ def _traverse_tables(scope): scope.branch( expression, lateral_sources=lateral_sources, - outer_column_list=expression.alias_column_names, + outer_columns=expression.alias_column_names, scope_type=scope_type, ) ): @@ -719,13 +750,13 @@ def _traverse_udtfs(scope): sources = {} for expression in expressions: - if isinstance(expression, exp.Subquery) and _is_derived_table(expression): + if _is_derived_table(expression): top = None for child_scope in _traverse_scope( scope.branch( expression, scope_type=ScopeType.DERIVED_TABLE, - outer_column_list=expression.alias_column_names, + outer_columns=expression.alias_column_names, ) ): yield child_scope @@ -738,18 +769,6 @@ def _traverse_udtfs(scope): scope.sources.update(sources) -def _traverse_ddl(scope): - yield from _traverse_ctes(scope) - - query_scope = scope.branch( - scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources - ) - query_scope._collect() - query_scope._ctes = scope.ctes + query_scope._ctes - - yield from _traverse_scope(query_scope) - - def walk_in_scope(expression, bfs=True, prune=None): """ Returns a generator object which visits all nodes in the syntrax tree, stopping at @@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None): # Whenever we set it to True, we exclude a subtree from traversal. crossed_scope_boundary = False - for node, parent, key in expression.walk( - bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args)) + for node in expression.walk( + bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) ): crossed_scope_boundary = False - yield node, parent, key + yield node if node is expression: continue if ( isinstance(node, exp.CTE) or ( - isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) - and _is_derived_table(node) + isinstance(node.parent, (exp.From, exp.Join, exp.Subquery)) + and (_is_derived_table(node) or isinstance(node, exp.UDTF)) ) - or isinstance(node, exp.UDTF) or isinstance(node, exp.UNWRAPPED_QUERIES) ): crossed_scope_boundary = True @@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True): Yields: exp.Expression: nodes """ - for expression, *_ in walk_in_scope(expression, bfs=bfs): + for expression in walk_in_scope(expression, bfs=bfs): if isinstance(expression, tuple(ensure_collection(expression_types))): yield expression |