summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/unnest_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/unnest_subqueries.py')
-rw-r--r-- sqlglot/optimizer/unnest_subqueries.py 27
1 files changed, 21 insertions, 6 deletions
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 2046917..8d78294 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -16,7 +16,7 @@ def unnest_subqueries(expression):
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
- AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
+ AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
Args:
expression (sqlglot.Expression): expression to unnest
@@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
table_alias = _alias(sequence)
keys = []
- # for all external columns in the where statement,
- # split out the relevant data to convert it into a join
+ # for all external columns in the where statement, find the relevant predicate
+ # keys to convert it into a join
for column in external_columns:
if column.find_ancestor(exp.Where) is not where:
return
@@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence):
if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
return
+ is_subquery_projection = any(
+ node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
+ )
+
value = select.selects[0]
key_aliases = {}
group_by = []
@@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence):
parent_predicate = select.find_ancestor(exp.Predicate)
# if the value of the subquery is not an agg or a key, we need to collect it into an array
- # so that it can be grouped
+ # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
+ agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
if not value.find(exp.AggFunc) and value.this not in group_by:
- select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
+ select.select(
+ exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
+ append=False,
+ copy=False,
+ )
# exists queries should not have any selects as it only checks if there are any rows
# all selects will be added by the optimizer and only used for join keys
@@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
if isinstance(parent_predicate, exp.Exists) or key != value.this:
select.select(f"{key} AS {alias}", copy=False)
else:
- select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
+ select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate)
@@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence):
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
)
else:
+ if is_subquery_projection:
+ alias = exp.alias_(alias, select.parent.alias)
select.parent.replace(alias)
for key, column, predicate in keys:
predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
+ if is_subquery_projection:
+ key.replace(nested)
+ continue
+
if key in group_by:
key.replace(nested)
parent_predicate = _replace(