diff options
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- | sqlglot/optimizer/scope.py | 87 |
1 files changed, 74 insertions, 13 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index e816e10..be6cfb9 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,3 +1,4 @@ +import itertools from copy import copy from enum import Enum, auto @@ -32,10 +33,11 @@ class Scope: The inner query would have `["col1", "col2"]` for its `outer_column_list` parent (Scope): Parent scope scope_type (ScopeType): Type of this scope, relative to it's parent - subquery_scopes (list[Scope]): List of all child scopes for subqueries. - This does not include derived tables or CTEs. - union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be - a tuple of the left and right child scopes. + subquery_scopes (list[Scope]): List of all child scopes for subqueries + cte_scopes = (list[Scope]) List of all child scopes for CTEs + derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables + union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be + a list of the left and right child scopes. """ def __init__( @@ -52,7 +54,9 @@ class Scope: self.parent = parent self.scope_type = scope_type self.subquery_scopes = [] - self.union = None + self.derived_table_scopes = [] + self.cte_scopes = [] + self.union_scopes = [] self.clear_cache() def clear_cache(self): @@ -197,11 +201,16 @@ class Scope: named_outputs = {e.alias_or_name for e in self.expression.expressions} - self._columns = [ - c - for c in columns + external_columns - if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs) - ] + self._columns = [] + for column in columns + external_columns: + ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint) + if ( + not ancestor + or column.table + or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint)) + ): + self._columns.append(column) + return self._columns @property @@ -284,6 +293,26 @@ class Scope: return self.scope_type == ScopeType.SUBQUERY @property + def is_derived_table(self): + """Determine if this scope is a derived table""" + return self.scope_type == ScopeType.DERIVED_TABLE + + @property + def is_union(self): + """Determine if this scope is a union""" + return self.scope_type == ScopeType.UNION + + @property + def is_cte(self): + """Determine if this scope is a common table expression""" + return self.scope_type == ScopeType.CTE + + @property + def is_root(self): + """Determine if this is the root scope""" + return self.scope_type == ScopeType.ROOT + + @property def is_unnest(self): """Determine if this scope is an unnest""" return self.scope_type == ScopeType.UNNEST @@ -308,6 +337,22 @@ class Scope: self.sources.pop(name, None) self.clear_cache() + def __repr__(self): + return f"Scope<{self.expression.sql()}>" + + def traverse(self): + """ + Traverse the scope tree from this node. + + Yields: + Scope: scope instances in depth-first-search post-order + """ + for child_scope in itertools.chain( + self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes + ): + yield from child_scope.traverse() + yield self + def traverse_scope(expression): """ @@ -337,6 +382,18 @@ def traverse_scope(expression): return list(_traverse_scope(Scope(expression))) +def build_scope(expression): + """ + Build a scope tree. + + Args: + expression (exp.Expression): expression to build the scope tree for + Returns: + Scope: root scope + """ + return traverse_scope(expression)[-1] + + def _traverse_scope(scope): if isinstance(scope.expression, exp.Select): yield from _traverse_select(scope) @@ -370,13 +427,14 @@ def _traverse_union(scope): for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): yield right - scope.union = (left, right) + scope.union_scopes = [left, right] def _traverse_derived_tables(derived_tables, scope, scope_type): sources = {} for derived_table in derived_tables: + top = None for child_scope in _traverse_scope( scope.branch( derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, @@ -386,11 +444,16 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): ) ): yield child_scope + top = child_scope # Tables without aliases will be set as "" # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. sources[derived_table.alias] = child_scope + if scope_type == ScopeType.CTE: + scope.cte_scopes.append(top) + else: + scope.derived_table_scopes.append(top) scope.sources.update(sources) @@ -407,8 +470,6 @@ def _add_table_sources(scope): if table_name in scope.sources: # This is a reference to a parent source (e.g. a CTE), not an actual table. scope.sources[source_name] = scope.sources[table_name] - elif source_name in scope.sources: - raise OptimizeError(f"Duplicate table name: {source_name}") else: sources[source_name] = table |