summary refs log tree commit diff stats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- sqlglot/optimizer/merge_subqueries.py 13
-rw-r--r-- sqlglot/optimizer/qualify_columns.py 10
-rw-r--r-- sqlglot/optimizer/scope.py 122
3 files changed, 105 insertions, 40 deletions
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 9d966b7..d29c22b 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -32,8 +32,8 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Returns:
sqlglot.Expression: optimized expression
"""
- merge_ctes(expression, leave_tables_isolated)
- merge_derived_tables(expression, leave_tables_isolated)
+ expression = merge_ctes(expression, leave_tables_isolated)
+ expression = merge_derived_tables(expression, leave_tables_isolated)
return expression
@@ -76,14 +76,14 @@ def merge_ctes(expression, leave_tables_isolated=False):
alias = node_to_replace.alias
else:
alias = table.name
-
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
- _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
+ _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_pop_cte(inner_scope)
+ return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
@@ -97,10 +97,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
- _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
+ _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
+ return expression
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
@@ -229,7 +230,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
continue
columns_to_replace = outer_columns.get(projection_name, [])
for column in columns_to_replace:
- column.replace(expression.unalias())
+ column.replace(expression.unalias().copy())
def _merge_where(outer_scope, inner_scope, from_or_join):
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 0bb947a..72ce256 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -5,8 +5,6 @@ from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
-SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
-
def qualify_columns(expression, schema):
"""
@@ -35,7 +33,7 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
_qualify_columns(scope, resolver)
- if not isinstance(scope.expression, SKIP_QUALIFY):
+ if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
_check_unknown_tables(scope)
@@ -50,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
- if isinstance(derived_table, SKIP_QUALIFY):
+ if isinstance(derived_table, exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
@@ -202,7 +200,7 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
- if not scope.is_subquery and not scope.is_unnest:
+ if not scope.is_subquery and not scope.is_udtf:
if column_name not in resolver.all_columns:
raise OptimizeError(f"Unknown column: {column_name}")
@@ -296,7 +294,7 @@ def _qualify_outputs(scope):
def _check_unknown_tables(scope):
- if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
+ if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index be6cfb9..6332cdd 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,5 +1,4 @@
import itertools
-from copy import copy
from enum import Enum, auto
from sqlglot import exp
@@ -12,7 +11,7 @@ class ScopeType(Enum):
DERIVED_TABLE = auto()
CTE = auto()
UNION = auto()
- UNNEST = auto()
+ UDTF = auto()
class Scope:
@@ -70,14 +69,11 @@ class Scope:
self._columns = None
self._external_columns = None
- def branch(self, expression, scope_type, add_sources=None, **kwargs):
+ def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
- sources = copy(self.sources)
- if add_sources:
- sources.update(add_sources)
return Scope(
expression=expression.unnest(),
- sources=sources,
+ sources={**self.cte_sources, **(chain_sources or {})},
parent=self,
scope_type=scope_type,
**kwargs,
@@ -90,30 +86,21 @@ class Scope:
self._derived_tables = []
self._raw_columns = []
- # We'll use this variable to pass state into the dfs generator.
- # Whenever we set it to True, we exclude a subtree from traversal.
- prune = False
-
- for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
- prune = False
-
+ for node, parent, _ in self.walk(bfs=False):
if node is self.expression:
continue
- if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+ elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
self._tables.append(node)
- elif isinstance(node, (exp.Unnest, exp.Lateral)):
+ elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
- prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
- prune = True
elif isinstance(node, exp.Subqueryable):
self._subqueries.append(node)
- prune = True
self._collected = True
@@ -121,6 +108,43 @@ class Scope:
if not self._collected:
self._collect()
+ def walk(self, bfs=True):
+ return walk_in_scope(self.expression, bfs=bfs)
+
+ def find(self, *expression_types, bfs=True):
+ """
+ Returns the first node in this scope which matches at least one of the specified types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Returns:
+ exp.Expression: the node which matches the criteria or None if no node matching
+ the criteria was found.
+ """
+ return next(self.find_all(*expression_types, bfs=bfs), None)
+
+ def find_all(self, *expression_types, bfs=True):
+ """
+ Returns a generator object which visits all nodes in this scope and only yields those that
+ match at least one of the specified expression types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Yields:
+ exp.Expression: nodes
+ """
+ for expression, _, _ in self.walk(bfs=bfs):
+ if isinstance(expression, expression_types):
+ yield expression
+
def replace(self, old, new):
"""
Replace `old` with `new`.
@@ -247,6 +271,16 @@ class Scope:
return self._selected_sources
@property
+ def cte_sources(self):
+ """
+ Sources that are CTEs.
+
+ Returns:
+ dict[str, Scope]: Mapping of source alias to Scope
+ """
+ return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
+
+ @property
def selects(self):
"""
Select expressions of this scope.
@@ -313,9 +347,9 @@ class Scope:
return self.scope_type == ScopeType.ROOT
@property
- def is_unnest(self):
- """Determine if this scope is an unnest"""
- return self.scope_type == ScopeType.UNNEST
+ def is_udtf(self):
+ """Determine if this scope is a UDTF (User Defined Table Function)"""
+ return self.scope_type == ScopeType.UDTF
@property
def is_correlated_subquery(self):
@@ -348,7 +382,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
- self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
+ self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@@ -399,7 +433,7 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
- elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
+ elif isinstance(scope.expression, exp.UDTF):
pass
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
@@ -410,8 +444,8 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
- yield from _traverse_subqueries(scope)
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
+ yield from _traverse_subqueries(scope)
_add_table_sources(scope)
@@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
top = None
for child_scope in _traverse_scope(
scope.branch(
- derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
- add_sources=sources if scope_type == ScopeType.CTE else None,
+ derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
+ chain_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
- scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
+ scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
)
):
yield child_scope
@@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
+
+
+def walk_in_scope(expression, bfs=True):
+ """
+ Returns a generator object which visits all nodes in the syntax tree, stopping at
+ nodes that start child scopes.
+
+ Args:
+ expression (exp.Expression):
+ bfs (bool): if set to True the BFS traversal order will be applied,
+ otherwise the DFS traversal will be used instead.
+
+ Yields:
+ tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
+ """
+ # We'll use this variable to pass state into the walk generator.
+ # Whenever we set it to True, we exclude a subtree from traversal.
+ prune = False
+
+ for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
+ prune = False
+
+ yield node, parent, key
+
+ if node is expression:
+ continue
+ elif isinstance(node, exp.CTE):
+ prune = True
+ elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
+ prune = True
+ elif isinstance(node, exp.Subqueryable):
+ prune = True