From f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 19 Dec 2023 12:01:55 +0100 Subject: Merging upstream version 20.3.0. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/scope.py | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) (limited to 'sqlglot/optimizer/scope.py') diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index b7e527e..d34857d 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -37,6 +37,7 @@ class Scope: For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source. + cte_sources (dict[str, Scope]): Sources from CTES outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list of it's alias of this scope, this is that list of columns. For example: @@ -61,11 +62,14 @@ class Scope: parent=None, scope_type=ScopeType.ROOT, lateral_sources=None, + cte_sources=None, ): self.expression = expression self.sources = sources or {} - self.lateral_sources = lateral_sources.copy() if lateral_sources else {} + self.lateral_sources = lateral_sources or {} + self.cte_sources = cte_sources or {} self.sources.update(self.lateral_sources) + self.sources.update(self.cte_sources) self.outer_column_list = outer_column_list or [] self.parent = parent self.scope_type = scope_type @@ -92,13 +96,17 @@ class Scope: self._pivots = None self._references = None - def branch(self, expression, scope_type, chain_sources=None, **kwargs): + def branch( + self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs + ): """Branch from the current scope to a new, inner scope""" return Scope( expression=expression.unnest(), - sources={**self.cte_sources, **(chain_sources or {})}, + sources=sources.copy() if sources else None, parent=self, scope_type=scope_type, + cte_sources={**self.cte_sources, **(cte_sources or {})}, + lateral_sources=lateral_sources.copy() if lateral_sources else None, **kwargs, ) @@ -305,20 +313,6 @@ class Scope: return self._references - @property - def cte_sources(self): - """ - Sources that are CTEs. - - Returns: - dict[str, Scope]: Mapping of source alias to Scope - """ - return { - alias: scope - for alias, scope in self.sources.items() - if isinstance(scope, Scope) and scope.is_cte - } - @property def external_columns(self): """ @@ -515,7 +509,10 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) elif isinstance(scope.expression, exp.Subquery): - yield from _traverse_subqueries(scope) + if scope.is_root: + yield from _traverse_select(scope) + else: + yield from _traverse_subqueries(scope) elif isinstance(scope.expression, exp.Table): yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): @@ -572,7 +569,7 @@ def _traverse_ctes(scope): for child_scope in _traverse_scope( scope.branch( cte.this, - chain_sources=sources, + cte_sources=sources, outer_column_list=cte.alias_column_names, scope_type=ScopeType.CTE, ) @@ -584,12 +581,14 @@ def _traverse_ctes(scope): if recursive_scope: child_scope.add_source(alias, recursive_scope) + child_scope.cte_sources[alias] = recursive_scope # append the final child_scope yielded if child_scope: scope.cte_scopes.append(child_scope) scope.sources.update(sources) + scope.cte_sources.update(sources) def _is_derived_table(expression: exp.Subquery) -> bool: @@ -725,7 +724,7 @@ def _traverse_ddl(scope): yield from _traverse_ctes(scope) query_scope = scope.branch( - scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources + scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources ) query_scope._collect() query_scope._ctes = scope.ctes + query_scope._ctes -- cgit v1.2.3