summary refs log tree commit diff stats
path: root/sqlglot/optimizer/unnest_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/unnest_subqueries.py')
-rw-r--r-- sqlglot/optimizer/unnest_subqueries.py 36
1 files changed, 28 insertions, 8 deletions
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index dbd680b..2046917 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -1,16 +1,15 @@
import itertools
from sqlglot import exp
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import ScopeType, traverse_scope
def unnest_subqueries(expression):
"""
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
- Convert the subquery into a group by so it is not a many to many left join.
- Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
- Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
+ Convert scalar subqueries into cross joins.
+ Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
Example:
>>> import sqlglot
@@ -29,21 +28,43 @@ def unnest_subqueries(expression):
for scope in traverse_scope(expression):
select = scope.expression
parent = select.parent_select
+ if not parent:
+ continue
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
- else:
+ elif scope.scope_type == ScopeType.SUBQUERY:
unnest(select, parent, sequence)
return expression
def unnest(select, parent_select, sequence):
- predicate = select.find_ancestor(exp.In, exp.Any)
+ if len(select.selects) > 1:
+ return
+
+ predicate = select.find_ancestor(exp.Condition)
+ alias = _alias(sequence)
if not predicate or parent_select is not predicate.parent_select:
return
- if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
+ # this subquery returns a scalar and can just be converted to a cross join
+ if not isinstance(predicate, (exp.In, exp.Any)):
+ having = predicate.find_ancestor(exp.Having)
+ column = exp.column(select.selects[0].alias_or_name, alias)
+ if having and having.parent_select is parent_select:
+ column = exp.Max(this=column)
+ _replace(select.parent, column)
+
+ parent_select.join(
+ select,
+ join_type="CROSS",
+ join_alias=alias,
+ copy=False,
+ )
+ return
+
+ if select.find(exp.Limit, exp.Offset):
return
if isinstance(predicate, exp.Any):
@@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):
column = _other_operand(predicate)
value = select.selects[0]
- alias = _alias(sequence)
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")