summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-12-19 11:01:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-12-19 11:01:55 +0000
commitf1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5 (patch)
tree5dce0fe2a11381761496eb973c20750f44db56d5 /sqlglot/optimizer/scope.py
parentReleasing debian version 20.1.0-1. (diff)
downloadsqlglot-f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5.tar.xz
sqlglot-f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5.zip
Merging upstream version 20.3.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py39
1 files changed, 19 insertions, 20 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index b7e527e..d34857d 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -37,6 +37,7 @@ class Scope:
For example:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
+ cte_sources (dict[str, Scope]): Sources from CTES
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list of it's alias of this scope, this is that list of columns.
For example:
@@ -61,11 +62,14 @@ class Scope:
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
+ cte_sources=None,
):
self.expression = expression
self.sources = sources or {}
- self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
+ self.lateral_sources = lateral_sources or {}
+ self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
+ self.sources.update(self.cte_sources)
self.outer_column_list = outer_column_list or []
self.parent = parent
self.scope_type = scope_type
@@ -92,13 +96,17 @@ class Scope:
self._pivots = None
self._references = None
- def branch(self, expression, scope_type, chain_sources=None, **kwargs):
+ def branch(
+ self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
+ ):
"""Branch from the current scope to a new, inner scope"""
return Scope(
expression=expression.unnest(),
- sources={**self.cte_sources, **(chain_sources or {})},
+ sources=sources.copy() if sources else None,
parent=self,
scope_type=scope_type,
+ cte_sources={**self.cte_sources, **(cte_sources or {})},
+ lateral_sources=lateral_sources.copy() if lateral_sources else None,
**kwargs,
)
@@ -306,20 +314,6 @@ class Scope:
return self._references
@property
- def cte_sources(self):
- """
- Sources that are CTEs.
-
- Returns:
- dict[str, Scope]: Mapping of source alias to Scope
- """
- return {
- alias: scope
- for alias, scope in self.sources.items()
- if isinstance(scope, Scope) and scope.is_cte
- }
-
- @property
def external_columns(self):
"""
Columns that appear to reference sources in outer scopes.
@@ -515,7 +509,10 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
- yield from _traverse_subqueries(scope)
+ if scope.is_root:
+ yield from _traverse_select(scope)
+ else:
+ yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
@@ -572,7 +569,7 @@ def _traverse_ctes(scope):
for child_scope in _traverse_scope(
scope.branch(
cte.this,
- chain_sources=sources,
+ cte_sources=sources,
outer_column_list=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
@@ -584,12 +581,14 @@ def _traverse_ctes(scope):
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
+ child_scope.cte_sources[alias] = recursive_scope
# append the final child_scope yielded
if child_scope:
scope.cte_scopes.append(child_scope)
scope.sources.update(sources)
+ scope.cte_sources.update(sources)
def _is_derived_table(expression: exp.Subquery) -> bool:
@@ -725,7 +724,7 @@ def _traverse_ddl(scope):
yield from _traverse_ctes(scope)
query_scope = scope.branch(
- scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+ scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
)
query_scope._collect()
query_scope._ctes = scope.ctes + query_scope._ctes