summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/unnest_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/unnest_subqueries.py')
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 09e3f2a..816f5fb 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name):
if not predicate or parent_select is not predicate.parent_select:
return
- # this subquery returns a scalar and can just be converted to a cross join
+ # This subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
- having = predicate.find_ancestor(exp.Having)
column = exp.column(select.selects[0].alias_or_name, alias)
- if having and having.parent_select is parent_select:
+
+ clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
+ clause_parent_select = clause.parent_select if clause else None
+
+ if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
+ (not clause or clause_parent_select is not parent_select)
+ and (
+ parent_select.args.get("group")
+ or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
+ )
+ ):
column = exp.Max(this=column)
- _replace(select.parent, column)
- parent_select.join(
- select,
- join_type="CROSS",
- join_alias=alias,
- copy=False,
- )
+ _replace(select.parent, column)
+ parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
return
if select.find(exp.Limit, exp.Offset):