summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 16:13:01 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 16:13:01 +0000
commita7044b672667f2a0b48bd0b326b5a55b0815ef79 (patch)
tree4fb5238d47fb4709d47f766a74b8bbaa9c6f17d8 /sqlglot/optimizer/scope.py
parentReleasing debian version 23.12.1-1. (diff)
downloadsqlglot-a7044b672667f2a0b48bd0b326b5a55b0815ef79.tar.xz
sqlglot-a7044b672667f2a0b48bd0b326b5a55b0815ef79.zip
Merging upstream version 23.13.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py54
1 files changed, 25 insertions, 29 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index c589e24..cff75e4 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -12,6 +12,8 @@ from sqlglot.helper import ensure_collection, find_new_name, seq_get
logger = logging.getLogger("sqlglot")
+TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
+
class ScopeType(Enum):
ROOT = auto()
@@ -495,25 +497,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
A list of the created scope instances
"""
- if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
- # We ignore the DDL expression and build a scope for its query instead
- ddl_with = expression.args.get("with")
- expression = expression.expression
-
- # If the DDL has CTEs attached, we need to add them to the query, or
- # prepend them if the query itself already has CTEs attached to it
- if ddl_with:
- ddl_with.pop()
- query_ctes = expression.ctes
- if not query_ctes:
- expression.set("with", ddl_with)
- else:
- expression.args["with"].set("recursive", ddl_with.recursive)
- expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
-
- if isinstance(expression, exp.Query):
+ if isinstance(expression, TRAVERSABLES):
return list(_traverse_scope(Scope(expression)))
-
return []
@@ -531,25 +516,37 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
def _traverse_scope(scope):
- if isinstance(scope.expression, exp.Select):
+ expression = scope.expression
+
+ if isinstance(expression, exp.Select):
yield from _traverse_select(scope)
- elif isinstance(scope.expression, exp.Union):
+ elif isinstance(expression, exp.Union):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
- elif isinstance(scope.expression, exp.Subquery):
+ elif isinstance(expression, exp.Subquery):
if scope.is_root:
yield from _traverse_select(scope)
else:
yield from _traverse_subqueries(scope)
- elif isinstance(scope.expression, exp.Table):
+ elif isinstance(expression, exp.Table):
yield from _traverse_tables(scope)
- elif isinstance(scope.expression, exp.UDTF):
+ elif isinstance(expression, exp.UDTF):
yield from _traverse_udtfs(scope)
+ elif isinstance(expression, exp.DDL):
+ if isinstance(expression.expression, exp.Query):
+ yield from _traverse_ctes(scope)
+ yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
+ return
+ elif isinstance(expression, exp.DML):
+ yield from _traverse_ctes(scope)
+ for query in find_all_in_scope(expression, exp.Query):
+ # This check ensures we don't yield the CTE queries twice
+ if not isinstance(query.parent, exp.CTE):
+ yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
+ return
else:
- logger.warning(
- "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
- )
+ logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
return
yield scope
@@ -749,7 +746,7 @@ def _traverse_udtfs(scope):
for child_scope in _traverse_scope(
scope.branch(
expression,
- scope_type=ScopeType.DERIVED_TABLE,
+ scope_type=ScopeType.SUBQUERY,
outer_columns=expression.alias_column_names,
)
):
@@ -757,8 +754,7 @@ def _traverse_udtfs(scope):
top = child_scope
sources[expression.alias] = child_scope
- scope.derived_table_scopes.append(top)
- scope.table_scopes.append(top)
+ scope.subquery_scopes.append(top)
scope.sources.update(sources)