diff options
Diffstat (limited to 'sqlglot/optimizer/merge_subqueries.py')
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 44 |
1 files changed, 38 insertions, 6 deletions
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index d29c22b..3e435f5 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -44,6 +44,7 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - { "joins", "where", "order", + "hint", } @@ -67,21 +68,22 @@ def merge_ctes(expression, leave_tables_isolated=False): singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] for outer_scope, inner_scope, table in singular_cte_selections: inner_select = inner_scope.expression.unnest() - if _mergeable(outer_scope, inner_select, leave_tables_isolated): - from_or_join = table.find_ancestor(exp.From, exp.Join) - + from_or_join = table.find_ancestor(exp.From, exp.Join) + if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): node_to_replace = table if isinstance(node_to_replace.parent, exp.Alias): node_to_replace = node_to_replace.parent alias = node_to_replace.alias else: alias = table.name + _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, node_to_replace, alias) _merge_expressions(outer_scope, inner_scope, alias) _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) + _merge_hints(outer_scope, inner_scope) _pop_cte(inner_scope) return expression @@ -90,9 +92,9 @@ def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: inner_select = subquery.unnest() - if _mergeable(outer_scope, inner_select, leave_tables_isolated): + from_or_join = subquery.find_ancestor(exp.From, exp.Join) + if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): alias = subquery.alias_or_name - from_or_join = subquery.find_ancestor(exp.From, exp.Join) inner_scope = outer_scope.sources[alias] _rename_inner_sources(outer_scope, inner_scope, alias) @@ -101,10 +103,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) + _merge_hints(outer_scope, inner_scope) return expression -def _mergeable(outer_scope, inner_select, leave_tables_isolated): +def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): """ Return True if `inner_select` can be merged into outer query. @@ -112,6 +115,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated): outer_scope (Scope) inner_select (exp.Select) leave_tables_isolated (bool) + from_or_join (exp.From|exp.Join) Returns: bool: True if can be merged """ @@ -123,6 +127,16 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated): and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) + and not ( + isinstance(from_or_join, exp.Join) + and inner_select.args.get("where") + and from_or_join.side in {"FULL", "LEFT", "RIGHT"} + ) + and not ( + isinstance(from_or_join, exp.From) + and inner_select.args.get("where") + and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])) + ) ) @@ -170,6 +184,12 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): """ new_subquery = inner_scope.expression.args.get("from").expressions[0] node_to_replace.replace(new_subquery) + for join_hint in outer_scope.join_hints: + tables = join_hint.find_all(exp.Table) + for table in tables: + if table.alias_or_name == node_to_replace.alias_or_name: + new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery + table.set("this", exp.to_identifier(new_table.alias_or_name)) outer_scope.remove_source(alias) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) @@ -273,6 +293,18 @@ def _merge_order(outer_scope, inner_scope): outer_scope.expression.set("order", inner_scope.expression.args.get("order")) +def _merge_hints(outer_scope, inner_scope): + inner_scope_hint = inner_scope.expression.args.get("hint") + if not inner_scope_hint: + return + outer_scope_hint = outer_scope.expression.args.get("hint") + if outer_scope_hint: + for hint_expression in inner_scope_hint.expressions: + outer_scope_hint.append("expressions", hint_expression) + else: + outer_scope.expression.set("hint", inner_scope_hint) + + def _pop_cte(inner_scope): """ Remove CTE from the AST. |