summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py87
1 files changed, 74 insertions, 13 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index e816e10..be6cfb9 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,3 +1,4 @@
+import itertools
from copy import copy
from enum import Enum, auto
@@ -32,10 +33,11 @@ class Scope:
The inner query would have `["col1", "col2"]` for its `outer_column_list`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to it's parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries.
- This does not include derived tables or CTEs.
- union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
- a tuple of the left and right child scopes.
+ subquery_scopes (list[Scope]): List of all child scopes for subqueries
+ cte_scopes (list[Scope]): List of all child scopes for CTEs
+ derived_table_scopes (list[Scope]): List of all child scopes for derived tables
+ union_scopes (list[Scope]): If this Scope is for a Union expression, this will be
+ a list of the left and right child scopes.
"""
def __init__(
@@ -52,7 +54,9 @@ class Scope:
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
- self.union = None
+ self.derived_table_scopes = []
+ self.cte_scopes = []
+ self.union_scopes = []
self.clear_cache()
def clear_cache(self):
@@ -197,11 +201,16 @@ class Scope:
named_outputs = {e.alias_or_name for e in self.expression.expressions}
- self._columns = [
- c
- for c in columns + external_columns
- if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
- ]
+ self._columns = []
+ for column in columns + external_columns:
+ ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint)
+ if (
+ not ancestor
+ or column.table
+ or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
+ ):
+ self._columns.append(column)
+
return self._columns
@property
@@ -284,6 +293,26 @@ class Scope:
return self.scope_type == ScopeType.SUBQUERY
@property
+ def is_derived_table(self):
+ """Determine if this scope is a derived table"""
+ return self.scope_type == ScopeType.DERIVED_TABLE
+
+ @property
+ def is_union(self):
+ """Determine if this scope is a union"""
+ return self.scope_type == ScopeType.UNION
+
+ @property
+ def is_cte(self):
+ """Determine if this scope is a common table expression"""
+ return self.scope_type == ScopeType.CTE
+
+ @property
+ def is_root(self):
+ """Determine if this is the root scope"""
+ return self.scope_type == ScopeType.ROOT
+
+ @property
def is_unnest(self):
"""Determine if this scope is an unnest"""
return self.scope_type == ScopeType.UNNEST
@@ -308,6 +337,22 @@ class Scope:
self.sources.pop(name, None)
self.clear_cache()
+ def __repr__(self):
+ return f"Scope<{self.expression.sql()}>"
+
+ def traverse(self):
+ """
+ Traverse the scope tree from this node.
+
+ Yields:
+ Scope: scope instances in depth-first-search post-order
+ """
+ for child_scope in itertools.chain(
+ self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
+ ):
+ yield from child_scope.traverse()
+ yield self
+
def traverse_scope(expression):
"""
@@ -337,6 +382,18 @@ def traverse_scope(expression):
return list(_traverse_scope(Scope(expression)))
+def build_scope(expression):
+ """
+ Build a scope tree.
+
+ Args:
+ expression (exp.Expression): expression to build the scope tree for
+ Returns:
+ Scope: root scope
+ """
+ return traverse_scope(expression)[-1]
+
+
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
@@ -370,13 +427,14 @@ def _traverse_union(scope):
for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
yield right
- scope.union = (left, right)
+ scope.union_scopes = [left, right]
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
for derived_table in derived_tables:
+ top = None
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
@@ -386,11 +444,16 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
+ top = child_scope
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins.
sources[derived_table.alias] = child_scope
+ if scope_type == ScopeType.CTE:
+ scope.cte_scopes.append(top)
+ else:
+ scope.derived_table_scopes.append(top)
scope.sources.update(sources)
@@ -407,8 +470,6 @@ def _add_table_sources(scope):
if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
scope.sources[source_name] = scope.sources[table_name]
- elif source_name in scope.sources:
- raise OptimizeError(f"Duplicate table name: {source_name}")
else:
sources[source_name] = table