diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 59 |
1 files changed, 50 insertions, 9 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 8704e90..39e252c 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -68,6 +68,9 @@ def eliminate_subqueries(expression): for cte_scope in root.cte_scopes: # Append all the new CTEs from this existing CTE for scope in cte_scope.traverse(): + if scope is cte_scope: + # Don't try to eliminate this CTE itself + continue new_cte = _eliminate(scope, existing_ctes, taken) if new_cte: new_ctes.append(new_cte) @@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken): if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF): return _eliminate_derived_table(scope, existing_ctes, taken) + if scope.is_cte: + return _eliminate_cte(scope, existing_ctes, taken) + def _eliminate_union(scope, existing_ctes, taken): duplicate_cte_alias = existing_ctes.get(scope.expression) @@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + table = exp.alias_(exp.table_(name), alias=parent.alias or name) + parent.replace(table) + + return cte + + +def _eliminate_cte(scope, existing_ctes, taken): + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + with_ = parent.parent + parent.pop() + if not with_.expressions: + with_.pop() + + # Rename references to this CTE + for child_scope in scope.parent.traverse(): + for table, source in child_scope.selected_sources.values(): + if source is scope: + new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name) + table.replace(new_table) + + return cte + + +def _new_cte(scope, existing_ctes, taken): + """ + Returns: + tuple of (name, cte) + where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance. + If this CTE duplicates an existing CTE, `cte` will be None. + """ duplicate_cte_alias = existing_ctes.get(scope.expression) parent = scope.expression.parent - name = alias = parent.alias + name = parent.alias - if not alias: - name = alias = find_new_name(taken=taken, base="cte") + if not name: + name = find_new_name(taken=taken, base="cte") if duplicate_cte_alias: name = duplicate_cte_alias - elif taken.get(alias): - name = find_new_name(taken=taken, base=alias) + elif taken.get(name): + name = find_new_name(taken=taken, base=name) taken[name] = scope - table = exp.alias_(exp.table_(name), alias=alias) - parent.replace(table) - if not duplicate_cte_alias: existing_ctes[scope.expression] = name - return exp.CTE( + cte = exp.CTE( this=scope.expression, alias=exp.TableAlias(this=exp.to_identifier(name)), ) + else: + cte = None + return name, cte |