diff options
Diffstat (limited to 'sqlglot/optimizer/merge_subqueries.py')
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 54 |
1 files changed, 42 insertions, 12 deletions
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 9ae4966..16aaf17 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False): singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] for outer_scope, inner_scope, table in singular_cte_selections: - inner_select = inner_scope.expression.unnest() from_or_join = table.find_ancestor(exp.From, exp.Join) - if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): + if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): alias = table.alias_or_name - _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, table, alias) _merge_expressions(outer_scope, inner_scope, alias) @@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False): _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) _pop_cte(inner_scope) + outer_scope.clear_cache() return expression def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: - inner_select = subquery.unnest() from_or_join = subquery.find_ancestor(exp.From, exp.Join) - if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): - alias = subquery.alias_or_name - inner_scope = outer_scope.sources[alias] - + alias = subquery.alias_or_name + inner_scope = outer_scope.sources[alias] + if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, subquery, alias) _merge_expressions(outer_scope, inner_scope, alias) @@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) + outer_scope.clear_cache() return expression -def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): +def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): """ Return True if `inner_select` can be merged into outer query. Args: outer_scope (Scope) - inner_select (exp.Select) + inner_scope (Scope) leave_tables_isolated (bool) from_or_join (exp.From|exp.Join) Returns: bool: True if can be merged """ + inner_select = inner_scope.expression.unnest() def _is_a_window_expression_in_unmergable_operation(): window_expressions = inner_select.find_all(exp.Window) @@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): ] return any(window_expressions_in_unmergable) + def _outer_select_joins_on_inner_select_join(): + """ + All columns from the inner select in the ON clause must be from the first FROM table. + + That is, this can be merged: + SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + But this can't: + SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + """ + if not isinstance(from_or_join, exp.Join): + return False + + alias = from_or_join.this.alias_or_name + + on = from_or_join.args.get("on") + if not on: + return False + selections = [c.name for c in on.find_all(exp.Column) if c.table == alias] + inner_from = inner_scope.expression.args.get("from") + if not inner_from: + return False + inner_from_table = inner_from.expressions[0].alias_or_name + inner_projections = {s.alias_or_name: s for s in inner_scope.selects} + return any( + col.table != inner_from_table + for selection in selections + for col in inner_projections[selection].find_all(exp.Column) + ) + return ( isinstance(outer_scope.expression, exp.Select) and isinstance(inner_select, exp.Select) - and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) @@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) ) ) + and not _outer_select_joins_on_inner_select_join() and not _is_a_window_expression_in_unmergable_operation() ) @@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): """ taken = set(outer_scope.selected_sources) conflicts = taken.intersection(set(inner_scope.selected_sources)) - conflicts = conflicts - {alias} + conflicts -= {alias} for conflict in conflicts: new_name = find_new_name(taken, conflict) |