summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/unnest_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/unnest_subqueries.py')
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py38
1 files changed, 28 insertions, 10 deletions
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 8d78294..a515489 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -15,8 +15,7 @@ def unnest_subqueries(expression):
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
- 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
- AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
+ 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Args:
expression (sqlglot.Expression): expression to unnest
@@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
other = _other_operand(parent_predicate)
if isinstance(parent_predicate, exp.Exists):
- if value.this in group_by:
- parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
- else:
- parent_predicate = _replace(parent_predicate, "TRUE")
+ alias = exp.column(list(key_aliases.values())[0], table_alias)
+ parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
elif isinstance(parent_predicate, exp.All):
parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
@@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
if is_subquery_projection:
alias = exp.alias_(alias, select.parent.alias)
+
+ # COUNT always returns 0 on empty datasets, so we need take that into consideration here
+ # by transforming all counts into 0 and using that as the coalesced value
+ if value.find(exp.Count):
+
+ def remove_aggs(node):
+ if isinstance(node, exp.Count):
+ return exp.Literal.number(0)
+ elif isinstance(node, exp.AggFunc):
+ return exp.null()
+ return node
+
+ alias = exp.Coalesce(
+ this=alias,
+ expressions=[value.this.transform(remove_aggs)],
+ )
+
select.parent.replace(alias)
for key, column, predicate in keys:
@@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
- parent_predicate = _replace(
- parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
- )
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,
@@ -245,7 +256,14 @@ def _other_operand(expression):
if isinstance(expression, exp.In):
return expression.this
+ if isinstance(expression, (exp.Any, exp.All)):
+ return _other_operand(expression.parent)
+
if isinstance(expression, exp.Binary):
- return expression.right if expression.arg_key == "this" else expression.left
+ return (
+ expression.right
+ if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
+ else expression.left
+ )
return None