summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-04-08 08:11:53 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-04-08 08:12:02 +0000
commit: 8d36f5966675e23bee7026ba37ae0647fbf47300 (patch)
tree: df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot/optimizer/scope.py
parent: Releasing debian version 22.2.0-1. (diff)
download: sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz
sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- sqlglot/optimizer/scope.py 165
1 files changed, 91 insertions, 74 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 443fa6c..073ced2 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -8,7 +8,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
-from sqlglot.helper import ensure_collection, find_new_name
+from sqlglot.helper import ensure_collection, find_new_name, seq_get
logger = logging.getLogger("sqlglot")
@@ -38,11 +38,11 @@ class Scope:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
cte_sources (dict[str, Scope]): Sources from CTES
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
- defines a column list of it's alias of this scope, this is that list of columns.
+ outer_columns (list[str]): If this is a derived table or CTE, and the outer query
+ defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
- The inner query would have `["col1", "col2"]` for its `outer_column_list`
+ The inner query would have `["col1", "col2"]` for its `outer_columns`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to it's parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries
@@ -58,7 +58,7 @@ class Scope:
self,
expression,
sources=None,
- outer_column_list=None,
+ outer_columns=None,
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
@@ -70,7 +70,7 @@ class Scope:
self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
self.sources.update(self.cte_sources)
- self.outer_column_list = outer_column_list or []
+ self.outer_columns = outer_columns or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
@@ -119,10 +119,11 @@ class Scope:
self._raw_columns = []
self._join_hints = []
- for node, parent, _ in self.walk(bfs=False):
+ for node in self.walk(bfs=False):
if node is self.expression:
continue
- elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+
+ if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
self._tables.append(node)
@@ -132,10 +133,8 @@ class Scope:
self._udtfs.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
- elif (
- isinstance(node, exp.Subquery)
- and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
- and _is_derived_table(node)
+ elif _is_derived_table(node) and isinstance(
+ node.parent, (exp.From, exp.Join, exp.Subquery)
):
self._derived_tables.append(node)
elif isinstance(node, exp.UNWRAPPED_QUERIES):
@@ -438,11 +437,21 @@ class Scope:
Yields:
Scope: scope instances in depth-first-search post-order
"""
- for child_scope in itertools.chain(
- self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
- ):
- yield from child_scope.traverse()
- yield self
+ stack = [self]
+ result = []
+ while stack:
+ scope = stack.pop()
+ result.append(scope)
+ stack.extend(
+ itertools.chain(
+ scope.cte_scopes,
+ scope.union_scopes,
+ scope.table_scopes,
+ scope.subquery_scopes,
+ )
+ )
+
+ yield from reversed(result)
def ref_count(self):
"""
@@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Args:
- expression (exp.Expression): expression to traverse
+ expression: Expression to traverse
Returns:
- list[Scope]: scope instances
+ A list of the created scope instances
"""
- if isinstance(expression, exp.Query) or (
- isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
- ):
+ if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
+ # We ignore the DDL expression and build a scope for its query instead
+ ddl_with = expression.args.get("with")
+ expression = expression.expression
+
+ # If the DDL has CTEs attached, we need to add them to the query, or
+ # prepend them if the query itself already has CTEs attached to it
+ if ddl_with:
+ ddl_with.pop()
+ query_ctes = expression.ctes
+ if not query_ctes:
+ expression.set("with", ddl_with)
+ else:
+ expression.args["with"].set("recursive", ddl_with.recursive)
+ expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
+
+ if isinstance(expression, exp.Query):
return list(_traverse_scope(Scope(expression)))
return []
@@ -499,21 +522,21 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
Build a scope tree.
Args:
- expression (exp.Expression): expression to build the scope tree for
+ expression: Expression to build the scope tree for.
+
Returns:
- Scope: root scope
+ The root scope
"""
- scopes = traverse_scope(expression)
- if scopes:
- return scopes[-1]
- return None
+ return seq_get(traverse_scope(expression), -1)
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
+ yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
+ return
elif isinstance(scope.expression, exp.Subquery):
if scope.is_root:
yield from _traverse_select(scope)
@@ -523,8 +546,6 @@ def _traverse_scope(scope):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
yield from _traverse_udtfs(scope)
- elif isinstance(scope.expression, exp.DDL):
- yield from _traverse_ddl(scope)
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@@ -541,30 +562,38 @@ def _traverse_select(scope):
def _traverse_union(scope):
- yield from _traverse_ctes(scope)
+ prev_scope = None
+ union_scope_stack = [scope]
+ expression_stack = [scope.expression.right, scope.expression.left]
- # The last scope to be yield should be the top most scope
- left = None
- for left in _traverse_scope(
- scope.branch(
- scope.expression.left,
- outer_column_list=scope.outer_column_list,
- scope_type=ScopeType.UNION,
- )
- ):
- yield left
+ while expression_stack:
+ expression = expression_stack.pop()
+ union_scope = union_scope_stack[-1]
- right = None
- for right in _traverse_scope(
- scope.branch(
- scope.expression.right,
- outer_column_list=scope.outer_column_list,
+ new_scope = union_scope.branch(
+ expression,
+ outer_columns=union_scope.outer_columns,
scope_type=ScopeType.UNION,
)
- ):
- yield right
- scope.union_scopes = [left, right]
+ if isinstance(expression, exp.Union):
+ yield from _traverse_ctes(new_scope)
+
+ union_scope_stack.append(new_scope)
+ expression_stack.extend([expression.right, expression.left])
+ continue
+
+ for scope in _traverse_scope(new_scope):
+ yield scope
+
+ if prev_scope:
+ union_scope_stack.pop()
+ union_scope.union_scopes = [prev_scope, scope]
+ prev_scope = union_scope
+
+ yield union_scope
+ else:
+ prev_scope = scope
def _traverse_ctes(scope):
@@ -588,7 +617,7 @@ def _traverse_ctes(scope):
scope.branch(
cte.this,
cte_sources=sources,
- outer_column_list=cte.alias_column_names,
+ outer_columns=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
):
@@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
as it doesn't introduce a new scope. If an alias is present, it shadows all names
under the Subquery, so that's one exception to this rule.
"""
- return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
+ return isinstance(expression, exp.Subquery) and bool(
+ expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
+ )
def _traverse_tables(scope):
@@ -681,7 +712,7 @@ def _traverse_tables(scope):
scope.branch(
expression,
lateral_sources=lateral_sources,
- outer_column_list=expression.alias_column_names,
+ outer_columns=expression.alias_column_names,
scope_type=scope_type,
)
):
@@ -719,13 +750,13 @@ def _traverse_udtfs(scope):
sources = {}
for expression in expressions:
- if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
+ if _is_derived_table(expression):
top = None
for child_scope in _traverse_scope(
scope.branch(
expression,
scope_type=ScopeType.DERIVED_TABLE,
- outer_column_list=expression.alias_column_names,
+ outer_columns=expression.alias_column_names,
)
):
yield child_scope
@@ -738,18 +769,6 @@ def _traverse_udtfs(scope):
scope.sources.update(sources)
-def _traverse_ddl(scope):
- yield from _traverse_ctes(scope)
-
- query_scope = scope.branch(
- scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
- )
- query_scope._collect()
- query_scope._ctes = scope.ctes + query_scope._ctes
-
- yield from _traverse_scope(query_scope)
-
-
def walk_in_scope(expression, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in the syntrax tree, stopping at
@@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None):
# Whenever we set it to True, we exclude a subtree from traversal.
crossed_scope_boundary = False
- for node, parent, key in expression.walk(
- bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
+ for node in expression.walk(
+ bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
):
crossed_scope_boundary = False
- yield node, parent, key
+ yield node
if node is expression:
continue
if (
isinstance(node, exp.CTE)
or (
- isinstance(node, exp.Subquery)
- and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
- and _is_derived_table(node)
+ isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
+ and (_is_derived_table(node) or isinstance(node, exp.UDTF))
)
- or isinstance(node, exp.UDTF)
or isinstance(node, exp.UNWRAPPED_QUERIES)
):
crossed_scope_boundary = True
@@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True):
Yields:
exp.Expression: nodes
"""
- for expression, *_ in walk_in_scope(expression, bfs=bfs):
+ for expression in walk_in_scope(expression, bfs=bfs):
if isinstance(expression, tuple(ensure_collection(expression_types))):
yield expression