From 9f19773cebdc9476f2a3266d3c01c967c38fcd1e Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 30 Jun 2023 10:03:58 +0200 Subject: Merging upstream version 16.7.7. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/scope.py | 58 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 15 deletions(-) (limited to 'sqlglot/optimizer/scope.py') diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index aa56b83..bc649e4 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,4 +1,5 @@ import itertools +import logging import typing as t from collections import defaultdict from enum import Enum, auto @@ -7,6 +8,8 @@ from sqlglot import exp from sqlglot.errors import OptimizeError from sqlglot.helper import find_new_name +logger = logging.getLogger("sqlglot") + class ScopeType(Enum): ROOT = auto() @@ -85,6 +88,7 @@ class Scope: self._external_columns = None self._join_hints = None self._pivots = None + self._references = None def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" @@ -264,14 +268,19 @@ class Scope: self._columns = [] for column in columns + external_columns: ancestor = column.find_ancestor( - exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint + exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table ) if ( not ancestor or column.table or isinstance(ancestor, exp.Select) - or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window)) - or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) + or ( + isinstance(ancestor, exp.Order) + and ( + isinstance(ancestor.parent, exp.Window) + or column.name not in named_selects + ) + ) ): self._columns.append(column) @@ -289,15 +298,9 @@ class Scope: dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes """ if self._selected_sources is None: - referenced_names = [] - - for table in self.tables: - referenced_names.append((table.alias_or_name, table)) - for expression in itertools.chain(self.derived_tables, self.udtfs): - referenced_names.append((expression.alias, expression.unnest())) result = {} - for name, node in referenced_names: + for name, node in self.references: if name in result: raise OptimizeError(f"Alias already used: {name}") if name in self.sources: @@ -306,6 +309,23 @@ class Scope: self._selected_sources = result return self._selected_sources + @property + def references(self) -> t.List[t.Tuple[str, exp.Expression]]: + if self._references is None: + self._references = [] + + for table in self.tables: + self._references.append((table.alias_or_name, table)) + for expression in itertools.chain(self.derived_tables, self.udtfs): + self._references.append( + ( + expression.alias, + expression if expression.args.get("pivots") else expression.unnest(), + ) + ) + + return self._references + @property def cte_sources(self): """ @@ -378,9 +398,7 @@ class Scope: def pivots(self): if not self._pivots: self._pivots = [ - pivot - for node in self.tables + self.derived_tables - for pivot in node.args.get("pivots") or [] + pivot for _, node in self.references for pivot in node.args.get("pivots") or [] ] return self._pivots @@ -536,7 +554,11 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.UDTF): pass else: - raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}") + logger.warning( + "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) + ) + return + yield scope @@ -576,6 +598,8 @@ def _traverse_ctes(scope): if isinstance(union, exp.Union): recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) + child_scope = None + for child_scope in _traverse_scope( scope.branch( cte.this, @@ -593,7 +617,8 @@ def _traverse_ctes(scope): child_scope.add_source(alias, recursive_scope) # append the final child_scope yielded - scope.cte_scopes.append(child_scope) + if child_scope: + scope.cte_scopes.append(child_scope) scope.sources.update(sources) @@ -634,6 +659,9 @@ def _traverse_tables(scope): sources[source_name] = expression continue + if not isinstance(expression, exp.DerivedTable): + continue + if isinstance(expression, exp.UDTF): lateral_sources = sources scope_type = ScopeType.UDTF -- cgit v1.2.3