diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 16:13:01 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 16:13:01 +0000 |
commit | a7044b672667f2a0b48bd0b326b5a55b0815ef79 (patch) | |
tree | 4fb5238d47fb4709d47f766a74b8bbaa9c6f17d8 /sqlglot/optimizer/scope.py | |
parent | Releasing debian version 23.12.1-1. (diff) | |
download | sqlglot-a7044b672667f2a0b48bd0b326b5a55b0815ef79.tar.xz sqlglot-a7044b672667f2a0b48bd0b326b5a55b0815ef79.zip |
Merging upstream version 23.13.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- | sqlglot/optimizer/scope.py | 54 |
1 files changed, 25 insertions, 29 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index c589e24..cff75e4 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -12,6 +12,8 @@ from sqlglot.helper import ensure_collection, find_new_name, seq_get logger = logging.getLogger("sqlglot") +TRAVERSABLES = (exp.Query, exp.DDL, exp.DML) + class ScopeType(Enum): ROOT = auto() @@ -495,25 +497,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: Returns: A list of the created scope instances """ - if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query): - # We ignore the DDL expression and build a scope for its query instead - ddl_with = expression.args.get("with") - expression = expression.expression - - # If the DDL has CTEs attached, we need to add them to the query, or - # prepend them if the query itself already has CTEs attached to it - if ddl_with: - ddl_with.pop() - query_ctes = expression.ctes - if not query_ctes: - expression.set("with", ddl_with) - else: - expression.args["with"].set("recursive", ddl_with.recursive) - expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes]) - - if isinstance(expression, exp.Query): + if isinstance(expression, TRAVERSABLES): return list(_traverse_scope(Scope(expression))) - return [] @@ -531,25 +516,37 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]: def _traverse_scope(scope): - if isinstance(scope.expression, exp.Select): + expression = scope.expression + + if isinstance(expression, exp.Select): yield from _traverse_select(scope) - elif isinstance(scope.expression, exp.Union): + elif isinstance(expression, exp.Union): yield from _traverse_ctes(scope) yield from _traverse_union(scope) return - elif isinstance(scope.expression, exp.Subquery): + elif isinstance(expression, exp.Subquery): if scope.is_root: yield from _traverse_select(scope) else: yield from _traverse_subqueries(scope) - elif isinstance(scope.expression, exp.Table): + elif isinstance(expression, exp.Table): yield from _traverse_tables(scope) - elif isinstance(scope.expression, exp.UDTF): + elif isinstance(expression, exp.UDTF): yield from _traverse_udtfs(scope) + elif isinstance(expression, exp.DDL): + if isinstance(expression.expression, exp.Query): + yield from _traverse_ctes(scope) + yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) + return + elif isinstance(expression, exp.DML): + yield from _traverse_ctes(scope) + for query in find_all_in_scope(expression, exp.Query): + # This check ensures we don't yield the CTE queries twice + if not isinstance(query.parent, exp.CTE): + yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) + return else: - logger.warning( - "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) - ) + logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) return yield scope @@ -749,7 +746,7 @@ def _traverse_udtfs(scope): for child_scope in _traverse_scope( scope.branch( expression, - scope_type=ScopeType.DERIVED_TABLE, + scope_type=ScopeType.SUBQUERY, outer_columns=expression.alias_column_names, ) ): @@ -757,8 +754,7 @@ def _traverse_udtfs(scope): top = child_scope sources[expression.alias] = child_scope - scope.derived_table_scopes.append(top) - scope.table_scopes.append(top) + scope.subquery_scopes.append(top) scope.sources.update(sources) |