diff options
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- | sqlglot/optimizer/scope.py | 72 |
1 files changed, 60 insertions, 12 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index a7dab35..fb12384 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -435,7 +435,10 @@ class Scope: @property def is_correlated_subquery(self): """Determine if this scope is a correlated subquery""" - return bool(self.is_subquery and self.external_columns) + return bool( + (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral))) + and self.external_columns + ) def rename_source(self, old_name, new_name): """Rename a source in this scope""" @@ -486,7 +489,7 @@ class Scope: def traverse_scope(expression: exp.Expression) -> t.List[Scope]: """ - Traverse an expression by it's "scopes". + Traverse an expression by its "scopes". "Scope" represents the current context of a Select statement. @@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: Returns: list[Scope]: scope instances """ - if not isinstance(expression, exp.Unionable): - return [] - return list(_traverse_scope(Scope(expression))) + if isinstance(expression, exp.Unionable) or ( + isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable) + ): + return list(_traverse_scope(Scope(expression))) + + return [] def build_scope(expression: exp.Expression) -> t.Optional[Scope]: @@ -539,7 +545,9 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.Table): yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): - pass + yield from _traverse_udtfs(scope) + elif isinstance(scope.expression, exp.DDL): + yield from _traverse_ddl(scope) else: logger.warning( "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) @@ -576,10 +584,10 @@ def _traverse_ctes(scope): for cte in scope.ctes: recursive_scope = None - # if the scope is a recursive cte, it must be in the form of - # base_case UNION recursive. thus the recursive scope is the first - # section of the union. - if scope.expression.args["with"].recursive: + # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. + # thus the recursive scope is the first section of the union. + with_ = scope.expression.args.get("with") + if with_ and with_.recursive: union = cte.this if isinstance(union, exp.Union): @@ -692,8 +700,7 @@ def _traverse_tables(scope): # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. - alias = expression.alias - sources[alias] = child_scope + sources[expression.alias] = child_scope # append the final child_scope yielded scopes.append(child_scope) @@ -711,6 +718,47 @@ def _traverse_subqueries(scope): scope.subquery_scopes.append(top) +def _traverse_udtfs(scope): + if isinstance(scope.expression, exp.Unnest): + expressions = scope.expression.expressions + elif isinstance(scope.expression, exp.Lateral): + expressions = [scope.expression.this] + else: + expressions = [] + + sources = {} + for expression in expressions: + if isinstance(expression, exp.Subquery) and _is_derived_table(expression): + top = None + for child_scope in _traverse_scope( + scope.branch( + expression, + scope_type=ScopeType.DERIVED_TABLE, + outer_column_list=expression.alias_column_names, + ) + ): + yield child_scope + top = child_scope + sources[expression.alias] = child_scope + + scope.derived_table_scopes.append(top) + scope.table_scopes.append(top) + + scope.sources.update(sources) + + +def _traverse_ddl(scope): + yield from _traverse_ctes(scope) + + query_scope = scope.branch( + scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources + ) + query_scope._collect() + query_scope._ctes = scope.ctes + query_scope._ctes + + yield from _traverse_scope(query_scope) + + def walk_in_scope(expression, bfs=True): """ Returns a generator object which visits all nodes in the syntrax tree, stopping at |