summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py58
1 files changed, 22 insertions, 36 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index f6f59e8..e816e10 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -104,9 +104,7 @@ class Scope:
elif isinstance(node, exp.CTE):
self._ctes.append(node)
prune = True
- elif isinstance(node, exp.Subquery) and isinstance(
- parent, (exp.From, exp.Join)
- ):
+ elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
prune = True
elif isinstance(node, exp.Subqueryable):
@@ -195,20 +193,14 @@ class Scope:
self._ensure_collected()
columns = self._raw_columns
- external_columns = [
- column
- for scope in self.subquery_scopes
- for column in scope.external_columns
- ]
+ external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
named_outputs = {e.alias_or_name for e in self.expression.expressions}
self._columns = [
c
for c in columns + external_columns
- if not (
- c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
- )
+ if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
]
return self._columns
@@ -229,9 +221,7 @@ class Scope:
for table in self.tables:
referenced_names.append(
(
- table.parent.alias
- if isinstance(table.parent, exp.Alias)
- else table.name,
+ table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
table,
)
)
@@ -274,9 +264,7 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
- self._external_columns = [
- c for c in self.columns if c.table not in self.selected_sources
- ]
+ self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
return self._external_columns
def source_columns(self, source_name):
@@ -310,6 +298,16 @@ class Scope:
columns = self.sources.pop(old_name or "", [])
self.sources[new_name] = columns
+ def add_source(self, name, source):
+ """Add a source to this scope"""
+ self.sources[name] = source
+ self.clear_cache()
+
+ def remove_source(self, name):
+ """Remove a source from this scope"""
+ self.sources.pop(name, None)
+ self.clear_cache()
+
def traverse_scope(expression):
"""
@@ -334,7 +332,7 @@ def traverse_scope(expression):
Args:
expression (exp.Expression): expression to traverse
Returns:
- List[Scope]: scope instances
+ list[Scope]: scope instances
"""
return list(_traverse_scope(Scope(expression)))
@@ -356,9 +354,7 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
yield from _traverse_subqueries(scope)
- yield from _traverse_derived_tables(
- scope.derived_tables, scope, ScopeType.DERIVED_TABLE
- )
+ yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
_add_table_sources(scope)
@@ -367,15 +363,11 @@ def _traverse_union(scope):
# The last scope to be yield should be the top most scope
left = None
- for left in _traverse_scope(
- scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
- ):
+ for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
yield left
right = None
- for right in _traverse_scope(
- scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
- ):
+ for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
yield right
scope.union = (left, right)
@@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
for derived_table in derived_tables:
for child_scope in _traverse_scope(
scope.branch(
- derived_table
- if isinstance(derived_table, (exp.Unnest, exp.Lateral))
- else derived_table.this,
+ derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
add_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
- scope_type=ScopeType.UNNEST
- if isinstance(derived_table, exp.Unnest)
- else scope_type,
+ scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
)
):
yield child_scope
@@ -430,9 +418,7 @@ def _add_table_sources(scope):
def _traverse_subqueries(scope):
for subquery in scope.subqueries:
top = None
- for child_scope in _traverse_scope(
- scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
- ):
+ for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)