summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/merge_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/optimizer/merge_subqueries.py (renamed from sqlglot/optimizer/merge_derived_tables.py)149
1 files changed, 102 insertions, 47 deletions
diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_subqueries.py
index 8b161fb..9d966b7 100644
--- a/sqlglot/optimizer/merge_derived_tables.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -1,72 +1,127 @@
from collections import defaultdict
from sqlglot import expressions as exp
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.helper import find_new_name
+from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify
-def merge_derived_tables(expression):
+def merge_subqueries(expression, leave_tables_isolated=False):
"""
Rewrite sqlglot AST to merge derived tables into the outer query.
+ This also merges CTEs if they are selected from only once.
+
Example:
>>> import sqlglot
- >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
- >>> merge_derived_tables(expression).sql()
- 'SELECT x.a FROM x'
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
+ >>> merge_subqueries(expression).sql()
+ 'SELECT x.a FROM x JOIN y'
+
+ If `leave_tables_isolated` is True, this will not merge inner queries into outer
+ queries if it would result in multiple table selects in a single query:
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
+ >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
+ 'SELECT a FROM (SELECT x.a FROM x) JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Args:
expression (sqlglot.Expression): expression to optimize
+ leave_tables_isolated (bool):
Returns:
sqlglot.Expression: optimized expression
"""
+ merge_ctes(expression, leave_tables_isolated)
+ merge_derived_tables(expression, leave_tables_isolated)
+ return expression
+
+
+# If a derived table has these Select args, it can't be merged
+UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
+ "expressions",
+ "from",
+ "joins",
+ "where",
+ "order",
+}
+
+
+def merge_ctes(expression, leave_tables_isolated=False):
+ scopes = traverse_scope(expression)
+
+ # All places where we select from CTEs.
+ # We key on the CTE scope so we can detect CTES that are selected from multiple times.
+ cte_selections = defaultdict(list)
+ for outer_scope in scopes:
+ for table, inner_scope in outer_scope.selected_sources.values():
+ if isinstance(inner_scope, Scope) and inner_scope.is_cte:
+ cte_selections[id(inner_scope)].append(
+ (
+ outer_scope,
+ inner_scope,
+ table,
+ )
+ )
+
+ singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
+ for outer_scope, inner_scope, table in singular_cte_selections:
+ inner_select = inner_scope.expression.unnest()
+ if _mergeable(outer_scope, inner_select, leave_tables_isolated):
+ from_or_join = table.find_ancestor(exp.From, exp.Join)
+
+ node_to_replace = table
+ if isinstance(node_to_replace.parent, exp.Alias):
+ node_to_replace = node_to_replace.parent
+ alias = node_to_replace.alias
+ else:
+ alias = table.name
+
+ _rename_inner_sources(outer_scope, inner_scope, alias)
+ _merge_from(outer_scope, inner_scope, node_to_replace, alias)
+ _merge_joins(outer_scope, inner_scope, from_or_join)
+ _merge_expressions(outer_scope, inner_scope, alias)
+ _merge_where(outer_scope, inner_scope, from_or_join)
+ _merge_order(outer_scope, inner_scope)
+ _pop_cte(inner_scope)
+
+
+def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
- if (
- isinstance(outer_scope.expression, exp.Select)
- and isinstance(inner_select, exp.Select)
- and _mergeable(inner_select)
- ):
+ if _mergeable(outer_scope, inner_select, leave_tables_isolated):
alias = subquery.alias_or_name
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
- _merge_from(outer_scope, inner_scope, subquery)
+ _merge_from(outer_scope, inner_scope, subquery, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
- return expression
-# If a derived table has these Select args, it can't be merged
-UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
- "expressions",
- "from",
- "joins",
- "where",
- "order",
-}
-
-
-def _mergeable(inner_select):
+def _mergeable(outer_scope, inner_select, leave_tables_isolated):
"""
Return True if `inner_select` can be merged into outer query.
Args:
+ outer_scope (Scope)
inner_select (exp.Select)
+ leave_tables_isolated (bool)
Returns:
bool: True if can be merged
"""
return (
- isinstance(inner_select, exp.Select)
+ isinstance(outer_scope.expression, exp.Select)
+ and isinstance(inner_select, exp.Select)
+ and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
+ and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
)
@@ -84,7 +139,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
conflicts = conflicts - {alias}
for conflict in conflicts:
- new_name = _find_new_name(taken, conflict)
+ new_name = find_new_name(taken, conflict)
source, _ = inner_scope.selected_sources[conflict]
new_alias = exp.to_identifier(new_name)
@@ -102,34 +157,19 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
inner_scope.rename_source(conflict, new_name)
-def _find_new_name(taken, base):
- """
- Searches for a new source name.
-
- Args:
- taken (set[str]): set of taken names
- base (str): base name to alter
- """
- i = 2
- new = f"{base}_{i}"
- while new in taken:
- i += 1
- new = f"{base}_{i}"
- return new
-
-
-def _merge_from(outer_scope, inner_scope, subquery):
+def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
"""
Merge FROM clause of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
- subquery (exp.Subquery)
+ node_to_replace (exp.Subquery|exp.Table)
+ alias (str)
"""
new_subquery = inner_scope.expression.args.get("from").expressions[0]
- subquery.replace(new_subquery)
- outer_scope.remove_source(subquery.alias_or_name)
+ node_to_replace.replace(new_subquery)
+ outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
@@ -176,7 +216,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
inner_scope (sqlglot.optimizer.scope.Scope)
alias (str)
"""
- # Collect all columns that for the alias of the inner query
+ # Collect all columns that reference the alias of the inner query
outer_columns = defaultdict(list)
for column in outer_scope.columns:
if column.table == alias:
@@ -205,7 +245,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if not where or not where.this:
return
- if isinstance(from_or_join, exp.Join) and from_or_join.side:
+ if isinstance(from_or_join, exp.Join):
# Merge predicates from an outer join to the ON clause
from_or_join.on(where.this, copy=False)
from_or_join.set("on", simplify(from_or_join.args.get("on")))
@@ -230,3 +270,18 @@ def _merge_order(outer_scope, inner_scope):
return
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
+
+
+def _pop_cte(inner_scope):
+ """
+ Remove CTE from the AST.
+
+ Args:
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ """
+ cte = inner_scope.expression.parent
+ with_ = cte.parent
+ if len(with_.expressions) == 1:
+ with_.pop()
+ else:
+ cte.pop()