summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py72
1 files changed, 60 insertions, 12 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index a7dab35..fb12384 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -435,7 +435,10 @@ class Scope:
@property
def is_correlated_subquery(self):
"""Determine if this scope is a correlated subquery"""
- return bool(self.is_subquery and self.external_columns)
+ return bool(
+ (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
+ and self.external_columns
+ )
def rename_source(self, old_name, new_name):
"""Rename a source in this scope"""
@@ -486,7 +489,7 @@ class Scope:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
"""
- Traverse an expression by it's "scopes".
+ Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
@@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
list[Scope]: scope instances
"""
- if not isinstance(expression, exp.Unionable):
- return []
- return list(_traverse_scope(Scope(expression)))
+ if isinstance(expression, exp.Unionable) or (
+ isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
+ ):
+ return list(_traverse_scope(Scope(expression)))
+
+ return []
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
@@ -539,7 +545,9 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
- pass
+ yield from _traverse_udtfs(scope)
+ elif isinstance(scope.expression, exp.DDL):
+ yield from _traverse_ddl(scope)
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@@ -576,10 +584,10 @@ def _traverse_ctes(scope):
for cte in scope.ctes:
recursive_scope = None
- # if the scope is a recursive cte, it must be in the form of
- # base_case UNION recursive. thus the recursive scope is the first
- # section of the union.
- if scope.expression.args["with"].recursive:
+ # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
+ # thus the recursive scope is the first section of the union.
+ with_ = scope.expression.args.get("with")
+ if with_ and with_.recursive:
union = cte.this
if isinstance(union, exp.Union):
@@ -692,8 +700,7 @@ def _traverse_tables(scope):
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins.
- alias = expression.alias
- sources[alias] = child_scope
+ sources[expression.alias] = child_scope
# append the final child_scope yielded
scopes.append(child_scope)
@@ -711,6 +718,47 @@ def _traverse_subqueries(scope):
scope.subquery_scopes.append(top)
+def _traverse_udtfs(scope):
+ if isinstance(scope.expression, exp.Unnest):
+ expressions = scope.expression.expressions
+ elif isinstance(scope.expression, exp.Lateral):
+ expressions = [scope.expression.this]
+ else:
+ expressions = []
+
+ sources = {}
+ for expression in expressions:
+ if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
+ top = None
+ for child_scope in _traverse_scope(
+ scope.branch(
+ expression,
+ scope_type=ScopeType.DERIVED_TABLE,
+ outer_column_list=expression.alias_column_names,
+ )
+ ):
+ yield child_scope
+ top = child_scope
+ sources[expression.alias] = child_scope
+
+ scope.derived_table_scopes.append(top)
+ scope.table_scopes.append(top)
+
+ scope.sources.update(sources)
+
+
+def _traverse_ddl(scope):
+ yield from _traverse_ctes(scope)
+
+ query_scope = scope.branch(
+ scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+ )
+ query_scope._collect()
+ query_scope._ctes = scope.ctes + query_scope._ctes
+
+ yield from _traverse_scope(query_scope)
+
+
def walk_in_scope(expression, bfs=True):
"""
Returns a generator object which visits all nodes in the syntrax tree, stopping at