summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py58
1 files changed, 43 insertions, 15 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index aa56b83..bc649e4 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,4 +1,5 @@
import itertools
+import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto
@@ -7,6 +8,8 @@ from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name
+logger = logging.getLogger("sqlglot")
+
class ScopeType(Enum):
ROOT = auto()
@@ -85,6 +88,7 @@ class Scope:
self._external_columns = None
self._join_hints = None
self._pivots = None
+ self._references = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
@@ -264,14 +268,19 @@ class Scope:
self._columns = []
for column in columns + external_columns:
ancestor = column.find_ancestor(
- exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
+ exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
)
if (
not ancestor
or column.table
or isinstance(ancestor, exp.Select)
- or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
- or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
+ or (
+ isinstance(ancestor, exp.Order)
+ and (
+ isinstance(ancestor.parent, exp.Window)
+ or column.name not in named_selects
+ )
+ )
):
self._columns.append(column)
@@ -289,15 +298,9 @@ class Scope:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
"""
if self._selected_sources is None:
- referenced_names = []
-
- for table in self.tables:
- referenced_names.append((table.alias_or_name, table))
- for expression in itertools.chain(self.derived_tables, self.udtfs):
- referenced_names.append((expression.alias, expression.unnest()))
result = {}
- for name, node in referenced_names:
+ for name, node in self.references:
if name in result:
raise OptimizeError(f"Alias already used: {name}")
if name in self.sources:
@@ -307,6 +310,23 @@ class Scope:
return self._selected_sources
@property
+ def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
+ if self._references is None:
+ self._references = []
+
+ for table in self.tables:
+ self._references.append((table.alias_or_name, table))
+ for expression in itertools.chain(self.derived_tables, self.udtfs):
+ self._references.append(
+ (
+ expression.alias,
+ expression if expression.args.get("pivots") else expression.unnest(),
+ )
+ )
+
+ return self._references
+
+ @property
def cte_sources(self):
"""
Sources that are CTEs.
@@ -378,9 +398,7 @@ class Scope:
def pivots(self):
if not self._pivots:
self._pivots = [
- pivot
- for node in self.tables + self.derived_tables
- for pivot in node.args.get("pivots") or []
+ pivot for _, node in self.references for pivot in node.args.get("pivots") or []
]
return self._pivots
@@ -536,7 +554,11 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.UDTF):
pass
else:
- raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
+ logger.warning(
+ "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
+ )
+ return
+
yield scope
@@ -576,6 +598,8 @@ def _traverse_ctes(scope):
if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
+ child_scope = None
+
for child_scope in _traverse_scope(
scope.branch(
cte.this,
@@ -593,7 +617,8 @@ def _traverse_ctes(scope):
child_scope.add_source(alias, recursive_scope)
# append the final child_scope yielded
- scope.cte_scopes.append(child_scope)
+ if child_scope:
+ scope.cte_scopes.append(child_scope)
scope.sources.update(sources)
@@ -634,6 +659,9 @@ def _traverse_tables(scope):
sources[source_name] = expression
continue
+ if not isinstance(expression, exp.DerivedTable):
+ continue
+
if isinstance(expression, exp.UDTF):
lateral_sources = sources
scope_type = ScopeType.UDTF