summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/merge_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/merge_subqueries.py')
-rw-r--r-- sqlglot/optimizer/merge_subqueries.py 54
1 files changed, 42 insertions, 12 deletions
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 9ae4966..16aaf17 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False):
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections:
- inner_select = inner_scope.expression.unnest()
from_or_join = table.find_ancestor(exp.From, exp.Join)
- if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
+ if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
alias = table.alias_or_name
-
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, table, alias)
_merge_expressions(outer_scope, inner_scope, alias)
@@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False):
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
_pop_cte(inner_scope)
+ outer_scope.clear_cache()
return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
- inner_select = subquery.unnest()
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
- if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
- alias = subquery.alias_or_name
- inner_scope = outer_scope.sources[alias]
-
+ alias = subquery.alias_or_name
+ inner_scope = outer_scope.sources[alias]
+ if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_expressions(outer_scope, inner_scope, alias)
@@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
+ outer_scope.clear_cache()
return expression
-def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
+def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
"""
Return True if `inner_select` can be merged into outer query.
Args:
outer_scope (Scope)
- inner_select (exp.Select)
+ inner_scope (Scope)
leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
+ inner_select = inner_scope.expression.unnest()
def _is_a_window_expression_in_unmergable_operation():
window_expressions = inner_select.find_all(exp.Window)
@@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
]
return any(window_expressions_in_unmergable)
+ def _outer_select_joins_on_inner_select_join():
+ """
+ All columns from the inner select in the ON clause must be from the first FROM table.
+
+ That is, this can be merged:
+ SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
+ ^^^ ^
+ But this can't:
+ SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
+ ^^^ ^
+ """
+ if not isinstance(from_or_join, exp.Join):
+ return False
+
+ alias = from_or_join.this.alias_or_name
+
+ on = from_or_join.args.get("on")
+ if not on:
+ return False
+ selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
+ inner_from = inner_scope.expression.args.get("from")
+ if not inner_from:
+ return False
+ inner_from_table = inner_from.expressions[0].alias_or_name
+ inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
+ return any(
+ col.table != inner_from_table
+ for selection in selections
+ for col in inner_projections[selection].find_all(exp.Column)
+ )
+
return (
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
- and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
@@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
)
)
+ and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
)
@@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
"""
taken = set(outer_scope.selected_sources)
conflicts = taken.intersection(set(inner_scope.selected_sources))
- conflicts = conflicts - {alias}
+ conflicts -= {alias}
for conflict in conflicts:
new_name = find_new_name(taken, conflict)