diff options
Diffstat (limited to 'sqlglot/optimizer/unnest_subqueries.py')
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 09e3f2a..816f5fb 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name): if not predicate or parent_select is not predicate.parent_select: return - # this subquery returns a scalar and can just be converted to a cross join + # This subquery returns a scalar and can just be converted to a cross join if not isinstance(predicate, (exp.In, exp.Any)): - having = predicate.find_ancestor(exp.Having) column = exp.column(select.selects[0].alias_or_name, alias) - if having and having.parent_select is parent_select: + + clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) + clause_parent_select = clause.parent_select if clause else None + + if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or ( + (not clause or clause_parent_select is not parent_select) + and ( + parent_select.args.get("group") + or any(projection.find(exp.AggFunc) for projection in parent_select.selects) + ) + ): column = exp.Max(this=column) - _replace(select.parent, column) - parent_select.join( - select, - join_type="CROSS", - join_alias=alias, - copy=False, - ) + _replace(select.parent, column) + parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False) return if select.find(exp.Limit, exp.Offset): |