summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/merge_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/merge_subqueries.py')
-rw-r--r--sqlglot/optimizer/merge_subqueries.py44
1 files changed, 38 insertions, 6 deletions
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index d29c22b..3e435f5 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -44,6 +44,7 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
"joins",
"where",
"order",
+ "hint",
}
@@ -67,21 +68,22 @@ def merge_ctes(expression, leave_tables_isolated=False):
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections:
inner_select = inner_scope.expression.unnest()
- if _mergeable(outer_scope, inner_select, leave_tables_isolated):
- from_or_join = table.find_ancestor(exp.From, exp.Join)
-
+ from_or_join = table.find_ancestor(exp.From, exp.Join)
+ if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
node_to_replace = table
if isinstance(node_to_replace.parent, exp.Alias):
node_to_replace = node_to_replace.parent
alias = node_to_replace.alias
else:
alias = table.name
+
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
+ _merge_hints(outer_scope, inner_scope)
_pop_cte(inner_scope)
return expression
@@ -90,9 +92,9 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
- if _mergeable(outer_scope, inner_select, leave_tables_isolated):
+ from_or_join = subquery.find_ancestor(exp.From, exp.Join)
+ if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
alias = subquery.alias_or_name
- from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
@@ -101,10 +103,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
+ _merge_hints(outer_scope, inner_scope)
return expression
-def _mergeable(outer_scope, inner_select, leave_tables_isolated):
+def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
"""
Return True if `inner_select` can be merged into outer query.
@@ -112,6 +115,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
outer_scope (Scope)
inner_select (exp.Select)
leave_tables_isolated (bool)
+ from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
@@ -123,6 +127,16 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
+ and not (
+ isinstance(from_or_join, exp.Join)
+ and inner_select.args.get("where")
+ and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
+ )
+ and not (
+ isinstance(from_or_join, exp.From)
+ and inner_select.args.get("where")
+ and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
+ )
)
@@ -170,6 +184,12 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
"""
new_subquery = inner_scope.expression.args.get("from").expressions[0]
node_to_replace.replace(new_subquery)
+ for join_hint in outer_scope.join_hints:
+ tables = join_hint.find_all(exp.Table)
+ for table in tables:
+ if table.alias_or_name == node_to_replace.alias_or_name:
+ new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
+ table.set("this", exp.to_identifier(new_table.alias_or_name))
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
@@ -273,6 +293,18 @@ def _merge_order(outer_scope, inner_scope):
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
+def _merge_hints(outer_scope, inner_scope):
+ inner_scope_hint = inner_scope.expression.args.get("hint")
+ if not inner_scope_hint:
+ return
+ outer_scope_hint = outer_scope.expression.args.get("hint")
+ if outer_scope_hint:
+ for hint_expression in inner_scope_hint.expressions:
+ outer_scope_hint.append("expressions", hint_expression)
+ else:
+ outer_scope.expression.set("hint", inner_scope_hint)
+
+
def _pop_cte(inner_scope):
"""
Remove CTE from the AST.