summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/eliminate_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/eliminate_subqueries.py')
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py144
1 files changed, 120 insertions, 24 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 4bfb733..38e1299 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -1,48 +1,144 @@
import itertools
-from sqlglot import alias, exp, select, table
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot import expressions as exp
+from sqlglot.helper import find_new_name
+from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
"""
- Rewrite duplicate subqueries from sqlglot AST.
+ Rewrite subqueries as CTES, deduplicating if possible.
Example:
>>> import sqlglot
- >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
- 'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
+ 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
+
+ This also deduplicates common subqueries:
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
+ >>> eliminate_subqueries(expression).sql()
+ 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
Args:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
+ expression (sqlglot.Expression): expression
Returns:
- sqlglot.Expression: qualified expression
+ sqlglot.Expression: expression
"""
+ if isinstance(expression, exp.Subquery):
+ # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
+ eliminate_subqueries(expression.this)
+ return expression
+
expression = simplify(expression)
- queries = {}
+ root = build_scope(expression)
+
+ # Map of alias->Scope|Table
+ # These are all aliases that are already used in the expression.
+ # We don't want to create new CTEs that conflict with these names.
+ taken = {}
+
+ # All CTE aliases in the root scope are taken
+ for scope in root.cte_scopes:
+ taken[scope.expression.parent.alias] = scope
+
+ # All table names are taken
+ for scope in root.traverse():
+ taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
- for scope in traverse_scope(expression):
- query = scope.expression
- queries[query] = queries.get(query, []) + [query]
+ # Map of Expression->alias
+ # Existing CTES in the root expression. We'll use this for deduplication.
+ existing_ctes = {}
- sequence = itertools.count()
+ with_ = root.expression.args.get("with")
+ if with_:
+ for cte in with_.expressions:
+ existing_ctes[cte.this] = cte.alias
+ new_ctes = []
- for query, duplicates in queries.items():
- if len(duplicates) == 1:
- continue
+ # We're adding more CTEs, but we want to maintain the DAG order.
+ # Derived tables within an existing CTE need to come before the existing CTE.
+ for cte_scope in root.cte_scopes:
+ # Append all the new CTEs from this existing CTE
+ for scope in cte_scope.traverse():
+ new_cte = _eliminate(scope, existing_ctes, taken)
+ if new_cte:
+ new_ctes.append(new_cte)
- alias_ = f"_e_{next(sequence)}"
+ # Append the existing CTE itself
+ new_ctes.append(cte_scope.expression.parent)
- for dup in duplicates:
- parent = dup.parent
- if isinstance(parent, exp.Subquery):
- parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
- elif isinstance(parent, exp.Union):
- dup.replace(select("*").from_(alias_))
+ # Now append the rest
+ for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
+ for child_scope in scope.traverse():
+ new_cte = _eliminate(child_scope, existing_ctes, taken)
+ if new_cte:
+ new_ctes.append(new_cte)
- expression.with_(alias_, as_=query, copy=False)
+ if new_ctes:
+ expression.set("with", exp.With(expressions=new_ctes))
return expression
+
+
+def _eliminate(scope, existing_ctes, taken):
+ if scope.is_union:
+ return _eliminate_union(scope, existing_ctes, taken)
+
+ if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
+ return _eliminate_derived_table(scope, existing_ctes, taken)
+
+
+def _eliminate_union(scope, existing_ctes, taken):
+ duplicate_cte_alias = existing_ctes.get(scope.expression)
+
+ alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
+
+ taken[alias] = scope
+
+ # Try to maintain the selections
+ expressions = scope.expression.args.get("expressions")
+ selects = [
+ exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
+ for e in expressions
+ if e.alias_or_name
+ ]
+ # If not all selections have an alias, just select *
+ if len(selects) != len(expressions):
+ selects = ["*"]
+
+ scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))
+
+ if not duplicate_cte_alias:
+ existing_ctes[scope.expression] = alias
+ return exp.CTE(
+ this=scope.expression,
+ alias=exp.TableAlias(this=exp.to_identifier(alias)),
+ )
+
+
+def _eliminate_derived_table(scope, existing_ctes, taken):
+ duplicate_cte_alias = existing_ctes.get(scope.expression)
+ parent = scope.expression.parent
+ name = alias = parent.alias
+
+ if not alias:
+ name = alias = find_new_name(taken=taken, base="cte")
+
+ if duplicate_cte_alias:
+ name = duplicate_cte_alias
+ elif taken.get(alias):
+ name = find_new_name(taken=taken, base=alias)
+
+ taken[name] = scope
+
+ table = exp.alias_(exp.table_(name), alias=alias)
+ parent.replace(table)
+
+ if not duplicate_cte_alias:
+ existing_ctes[scope.expression] = name
+ return exp.CTE(
+ this=scope.expression,
+ alias=exp.TableAlias(this=exp.to_identifier(name)),
+ )