diff options
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- | sqlglot/optimizer/scope.py | 43 |
1 files changed, 34 insertions, 9 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index e00b3c9..9ffb4d6 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,4 +1,5 @@ import itertools +import typing as t from collections import defaultdict from enum import Enum, auto @@ -83,6 +84,7 @@ class Scope: self._columns = None self._external_columns = None self._join_hints = None + self._pivots = None def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" @@ -261,12 +263,14 @@ class Scope: self._columns = [] for column in columns + external_columns: - ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) + ancestor = column.find_ancestor( + exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint + ) if ( not ancestor - # Window functions can have an ORDER BY clause - or not isinstance(ancestor.parent, exp.Select) or column.table + or isinstance(ancestor, exp.Select) + or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window)) or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) ): self._columns.append(column) @@ -370,6 +374,17 @@ class Scope: return [] return self._join_hints + @property + def pivots(self): + if not self._pivots: + self._pivots = [ + pivot + for node in self.tables + self.derived_tables + for pivot in node.args.get("pivots") or [] + ] + + return self._pivots + def source_columns(self, source_name): """ Get all columns in the current scope for a particular source. @@ -463,7 +478,7 @@ class Scope: return scope_ref_count -def traverse_scope(expression): +def traverse_scope(expression: exp.Expression) -> t.List[Scope]: """ Traverse an expression by it's "scopes". @@ -488,10 +503,12 @@ def traverse_scope(expression): Returns: list[Scope]: scope instances """ + if not isinstance(expression, exp.Unionable): + return [] return list(_traverse_scope(Scope(expression))) -def build_scope(expression): +def build_scope(expression: exp.Expression) -> t.Optional[Scope]: """ Build a scope tree. @@ -500,7 +517,10 @@ def build_scope(expression): Returns: Scope: root scope """ - return traverse_scope(expression)[-1] + scopes = traverse_scope(expression) + if scopes: + return scopes[-1] + return None def _traverse_scope(scope): @@ -585,7 +605,7 @@ def _traverse_tables(scope): expressions = [] from_ = scope.expression.args.get("from") if from_: - expressions.extend(from_.expressions) + expressions.append(from_.this) for join in scope.expression.args.get("joins") or []: expressions.append(join.this) @@ -601,8 +621,13 @@ def _traverse_tables(scope): source_name = expression.alias_or_name if table_name in scope.sources: - # This is a reference to a parent source (e.g. a CTE), not an actual table. - sources[source_name] = scope.sources[table_name] + # This is a reference to a parent source (e.g. a CTE), not an actual table, unless + # it is pivoted, because then we get back a new table and hence a new source. + pivots = expression.args.get("pivots") + if pivots: + sources[pivots[0].alias] = expression + else: + sources[source_name] = scope.sources[table_name] elif source_name in sources: sources[find_new_name(sources, table_name)] = expression else: |