From 621555af37594a213d91ea113d5fc7739af84d40 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 4 Jan 2023 08:24:05 +0100 Subject: Adding upstream version 10.4.2. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/unnest_subqueries.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) (limited to 'sqlglot/optimizer/unnest_subqueries.py') diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 2046917..8d78294 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -16,7 +16,7 @@ def unnest_subqueries(expression): >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ - AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)' + AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)' Args: expression (sqlglot.Expression): expression to unnest @@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence): table_alias = _alias(sequence) keys = [] - # for all external columns in the where statement, - # split out the relevant data to convert it into a join + # for all external columns in the where statement, find the relevant predicate + # keys to convert it into a join for column in external_columns: if column.find_ancestor(exp.Where) is not where: return @@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence): if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): return + is_subquery_projection = any( + node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery) + ) + value = select.selects[0] key_aliases = {} group_by = [] @@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence): parent_predicate = select.find_ancestor(exp.Predicate) # if the value of the subquery is not an agg or a key, we need to collect it into an array - # so that it can be grouped + # so that it can be grouped. For subquery projections, we use a MAX aggregation instead. + agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg if not value.find(exp.AggFunc) and value.this not in group_by: - select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False) + select.select( + exp.alias_(agg_func(this=value.this), value.alias, quoted=False), + append=False, + copy=False, + ) # exists queries should not have any selects as it only checks if there are any rows # all selects will be added by the optimizer and only used for join keys @@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence): if isinstance(parent_predicate, exp.Exists) or key != value.this: select.select(f"{key} AS {alias}", copy=False) else: - select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False) + select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False) alias = exp.column(value.alias, table_alias) other = _other_operand(parent_predicate) @@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence): f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", ) else: + if is_subquery_projection: + alias = exp.alias_(alias, select.parent.alias) select.parent.replace(alias) for key, column, predicate in keys: predicate.replace(exp.true()) nested = exp.column(key_aliases[key], table_alias) + if is_subquery_projection: + key.replace(nested) + continue + if key in group_by: key.replace(nested) parent_predicate = _replace( -- cgit v1.2.3