summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/eliminate_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/eliminate_subqueries.py')
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py59
1 files changed, 50 insertions, 9 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 8704e90..39e252c 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
for cte_scope in root.cte_scopes:
# Append all the new CTEs from this existing CTE
for scope in cte_scope.traverse():
+ if scope is cte_scope:
+ # Don't try to eliminate this CTE itself
+ continue
new_cte = _eliminate(scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
@@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
return _eliminate_derived_table(scope, existing_ctes, taken)
+ if scope.is_cte:
+ return _eliminate_cte(scope, existing_ctes, taken)
+
def _eliminate_union(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
@@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
+ parent = scope.expression.parent
+ name, cte = _new_cte(scope, existing_ctes, taken)
+
+ table = exp.alias_(exp.table_(name), alias=parent.alias or name)
+ parent.replace(table)
+
+ return cte
+
+
+def _eliminate_cte(scope, existing_ctes, taken):
+ parent = scope.expression.parent
+ name, cte = _new_cte(scope, existing_ctes, taken)
+
+ with_ = parent.parent
+ parent.pop()
+ if not with_.expressions:
+ with_.pop()
+
+ # Rename references to this CTE
+ for child_scope in scope.parent.traverse():
+ for table, source in child_scope.selected_sources.values():
+ if source is scope:
+ new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
+ table.replace(new_table)
+
+ return cte
+
+
+def _new_cte(scope, existing_ctes, taken):
+ """
+ Returns:
+ tuple of (name, cte)
+ where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
+ If this CTE duplicates an existing CTE, `cte` will be None.
+ """
duplicate_cte_alias = existing_ctes.get(scope.expression)
parent = scope.expression.parent
- name = alias = parent.alias
+ name = parent.alias
- if not alias:
- name = alias = find_new_name(taken=taken, base="cte")
+ if not name:
+ name = find_new_name(taken=taken, base="cte")
if duplicate_cte_alias:
name = duplicate_cte_alias
- elif taken.get(alias):
- name = find_new_name(taken=taken, base=alias)
+ elif taken.get(name):
+ name = find_new_name(taken=taken, base=name)
taken[name] = scope
- table = exp.alias_(exp.table_(name), alias=alias)
- parent.replace(table)
-
if not duplicate_cte_alias:
existing_ctes[scope.expression] = name
- return exp.CTE(
+ cte = exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(name)),
)
+ else:
+ cte = None
+ return name, cte