From 90150543f9314be683d22a16339effd774192f6d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 22 Sep 2022 06:31:28 +0200 Subject: Merging upstream version 6.1.1. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/scope.py | 58 ++++++++++++++++++---------------------------- 1 file changed, 22 insertions(+), 36 deletions(-) (limited to 'sqlglot/optimizer/scope.py') diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index f6f59e8..e816e10 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -104,9 +104,7 @@ class Scope: elif isinstance(node, exp.CTE): self._ctes.append(node) prune = True - elif isinstance(node, exp.Subquery) and isinstance( - parent, (exp.From, exp.Join) - ): + elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): self._derived_tables.append(node) prune = True elif isinstance(node, exp.Subqueryable): @@ -195,20 +193,14 @@ class Scope: self._ensure_collected() columns = self._raw_columns - external_columns = [ - column - for scope in self.subquery_scopes - for column in scope.external_columns - ] + external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns] named_outputs = {e.alias_or_name for e in self.expression.expressions} self._columns = [ c for c in columns + external_columns - if not ( - c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs - ) + if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs) ] return self._columns @@ -229,9 +221,7 @@ class Scope: for table in self.tables: referenced_names.append( ( - table.parent.alias - if isinstance(table.parent, exp.Alias) - else table.name, + table.parent.alias if isinstance(table.parent, exp.Alias) else table.name, table, ) ) @@ -274,9 +264,7 @@ class Scope: sources in the current scope. """ if self._external_columns is None: - self._external_columns = [ - c for c in self.columns if c.table not in self.selected_sources - ] + self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] return self._external_columns def source_columns(self, source_name): @@ -310,6 +298,16 @@ class Scope: columns = self.sources.pop(old_name or "", []) self.sources[new_name] = columns + def add_source(self, name, source): + """Add a source to this scope""" + self.sources[name] = source + self.clear_cache() + + def remove_source(self, name): + """Remove a source from this scope""" + self.sources.pop(name, None) + self.clear_cache() + def traverse_scope(expression): """ @@ -334,7 +332,7 @@ def traverse_scope(expression): Args: expression (exp.Expression): expression to traverse Returns: - List[Scope]: scope instances + list[Scope]: scope instances """ return list(_traverse_scope(Scope(expression))) @@ -356,9 +354,7 @@ def _traverse_scope(scope): def _traverse_select(scope): yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) yield from _traverse_subqueries(scope) - yield from _traverse_derived_tables( - scope.derived_tables, scope, ScopeType.DERIVED_TABLE - ) + yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) _add_table_sources(scope) @@ -367,15 +363,11 @@ def _traverse_union(scope): # The last scope to be yield should be the top most scope left = None - for left in _traverse_scope( - scope.branch(scope.expression.left, scope_type=ScopeType.UNION) - ): + for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)): yield left right = None - for right in _traverse_scope( - scope.branch(scope.expression.right, scope_type=ScopeType.UNION) - ): + for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): yield right scope.union = (left, right) @@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): for derived_table in derived_tables: for child_scope in _traverse_scope( scope.branch( - derived_table - if isinstance(derived_table, (exp.Unnest, exp.Lateral)) - else derived_table.this, + derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, add_sources=sources if scope_type == ScopeType.CTE else None, outer_column_list=derived_table.alias_column_names, - scope_type=ScopeType.UNNEST - if isinstance(derived_table, exp.Unnest) - else scope_type, + scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type, ) ): yield child_scope @@ -430,9 +418,7 @@ def _add_table_sources(scope): def _traverse_subqueries(scope): for subquery in scope.subqueries: top = None - for child_scope in _traverse_scope( - scope.branch(subquery, scope_type=ScopeType.SUBQUERY) - ): + for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): yield child_scope top = child_scope scope.subquery_scopes.append(top) -- cgit v1.2.3