summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py43
1 files changed, 34 insertions, 9 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index e00b3c9..9ffb4d6 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,4 +1,5 @@
import itertools
+import typing as t
from collections import defaultdict
from enum import Enum, auto
@@ -83,6 +84,7 @@ class Scope:
self._columns = None
self._external_columns = None
self._join_hints = None
+ self._pivots = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
@@ -261,12 +263,14 @@ class Scope:
self._columns = []
for column in columns + external_columns:
- ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
+ ancestor = column.find_ancestor(
+ exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
+ )
if (
not ancestor
- # Window functions can have an ORDER BY clause
- or not isinstance(ancestor.parent, exp.Select)
or column.table
+ or isinstance(ancestor, exp.Select)
+ or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
@@ -370,6 +374,17 @@ class Scope:
return []
return self._join_hints
+ @property
+ def pivots(self):
+ if not self._pivots:
+ self._pivots = [
+ pivot
+ for node in self.tables + self.derived_tables
+ for pivot in node.args.get("pivots") or []
+ ]
+
+ return self._pivots
+
def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.
@@ -463,7 +478,7 @@ class Scope:
return scope_ref_count
-def traverse_scope(expression):
+def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
"""
Traverse an expression by it's "scopes".
@@ -488,10 +503,12 @@ def traverse_scope(expression):
Returns:
list[Scope]: scope instances
"""
+ if not isinstance(expression, exp.Unionable):
+ return []
return list(_traverse_scope(Scope(expression)))
-def build_scope(expression):
+def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
"""
Build a scope tree.
@@ -500,7 +517,10 @@ def build_scope(expression):
Returns:
Scope: root scope
"""
- return traverse_scope(expression)[-1]
+ scopes = traverse_scope(expression)
+ if scopes:
+ return scopes[-1]
+ return None
def _traverse_scope(scope):
@@ -585,7 +605,7 @@ def _traverse_tables(scope):
expressions = []
from_ = scope.expression.args.get("from")
if from_:
- expressions.extend(from_.expressions)
+ expressions.append(from_.this)
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
@@ -601,8 +621,13 @@ def _traverse_tables(scope):
source_name = expression.alias_or_name
if table_name in scope.sources:
- # This is a reference to a parent source (e.g. a CTE), not an actual table.
- sources[source_name] = scope.sources[table_name]
+ # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
+ # it is pivoted, because then we get back a new table and hence a new source.
+ pivots = expression.args.get("pivots")
+ if pivots:
+ sources[pivots[0].alias] = expression
+ else:
+ sources[source_name] = scope.sources[table_name]
elif source_name in sources:
sources[find_new_name(sources, table_name)] = expression
else: