diff options
Diffstat (limited to 'sqlglot/optimizer/eliminate_subqueries.py')
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index a39fe96..84f50e9 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -16,9 +16,9 @@ def eliminate_subqueries(expression): 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y' This also deduplicates common subqueries: - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z") + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z") >>> eliminate_subqueries(expression).sql() - 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z' + 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z' Args: expression (sqlglot.Expression): expression @@ -32,6 +32,9 @@ def eliminate_subqueries(expression): root = build_scope(expression) + if not root: + return expression + # Map of alias->Scope|Table # These are all aliases that are already used in the expression. # We don't want to create new CTEs that conflict with these names. @@ -112,7 +115,7 @@ def _eliminate_union(scope, existing_ctes, taken): # Try to maintain the selections expressions = scope.selects selects = [ - exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name) + exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False) for e in expressions if e.alias_or_name ] @@ -120,7 +123,9 @@ def _eliminate_union(scope, existing_ctes, taken): if len(selects) != len(expressions): selects = ["*"] - scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias))) + scope.expression.replace( + exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False)) + ) if not duplicate_cte_alias: existing_ctes[scope.expression] = alias @@ -131,6 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): + # This ensures we don't drop the "pivot" arg from a pivoted subquery + if scope.parent.pivots: + return None + parent = scope.expression.parent name, cte = _new_cte(scope, existing_ctes, taken) @@ -153,7 +162,7 @@ def _eliminate_cte(scope, existing_ctes, taken): for child_scope in scope.parent.traverse(): for table, source in child_scope.selected_sources.values(): if source is scope: - new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name) + new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False) table.replace(new_table) return cte |