Edit on GitHub

sqlglot.optimizer.unnest_subqueries

  1from sqlglot import exp
  2from sqlglot.helper import name_sequence
  3from sqlglot.optimizer.scope import ScopeType, traverse_scope
  4
  5
  6def unnest_subqueries(expression):
  7    """
  8    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
  9
 10    Convert scalar subqueries into cross joins.
 11    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
 16        >>> unnest_subqueries(expression).sql()
 17        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to unnest
 21    Returns:
 22        sqlglot.Expression: unnested expression
 23    """
 24    next_alias_name = name_sequence("_u_")
 25
 26    for scope in traverse_scope(expression):
 27        select = scope.expression
 28        parent = select.parent_select
 29        if not parent:
 30            continue
 31        if scope.external_columns:
 32            decorrelate(select, parent, scope.external_columns, next_alias_name)
 33        elif scope.scope_type == ScopeType.SUBQUERY:
 34            unnest(select, parent, next_alias_name)
 35
 36    return expression
 37
 38
 39def unnest(select, parent_select, next_alias_name):
 40    if len(select.selects) > 1:
 41        return
 42
 43    predicate = select.find_ancestor(exp.Condition)
 44    alias = next_alias_name()
 45
 46    if not predicate or parent_select is not predicate.parent_select:
 47        return
 48
 49    # This subquery returns a scalar and can just be converted to a cross join
 50    if not isinstance(predicate, (exp.In, exp.Any)):
 51        column = exp.column(select.selects[0].alias_or_name, alias)
 52
 53        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
 54        clause_parent_select = clause.parent_select if clause else None
 55
 56        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
 57            (not clause or clause_parent_select is not parent_select)
 58            and (
 59                parent_select.args.get("group")
 60                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
 61            )
 62        ):
 63            column = exp.Max(this=column)
 64
 65        _replace(select.parent, column)
 66        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
 67        return
 68
 69    if select.find(exp.Limit, exp.Offset):
 70        return
 71
 72    if isinstance(predicate, exp.Any):
 73        predicate = predicate.find_ancestor(exp.EQ)
 74
 75        if not predicate or parent_select is not predicate.parent_select:
 76            return
 77
 78    column = _other_operand(predicate)
 79    value = select.selects[0]
 80
 81    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
 82    _replace(predicate, f"NOT {on.right} IS NULL")
 83
 84    parent_select.join(
 85        select.group_by(value.this, copy=False),
 86        on=on,
 87        join_type="LEFT",
 88        join_alias=alias,
 89        copy=False,
 90    )
 91
 92
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped LEFT JOIN on its parent select.

    The binary predicates in the subquery's WHERE clause that involve external
    columns are replaced with TRUE, the subquery is grouped by the collected
    keys and LEFT JOINed to `parent_select` on the equality predicates, and the
    expression that contained the subquery is rewritten to reference the joined
    table's columns. All builder calls use `copy=False`, so `select` and
    `parent_select` are modified in place.

    The function returns without rewriting anything when:
      * the subquery has no WHERE clause, its WHERE contains an OR, or the
        subquery contains a LIMIT or OFFSET;
      * an external column, or the predicate enclosing it, is not located
        under this subquery's own WHERE clause;
      * the predicate enclosing an external column is not an `exp.Binary`;
      * none of the collected predicates is an equality (`exp.EQ`).

    Note that the table alias is drawn from `next_alias_name` before the
    per-column checks, so a name is consumed even if one of those checks
    bails out.

    Args:
        select: the correlated subquery's select expression.
        parent_select: the select expression that contains the subquery.
        external_columns: the columns inside `select` that refer to an outer
            scope (the caller passes `scope.external_columns`).
        next_alias_name: callable returning a new alias name on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # list of (key, external column, predicate) triples, one per external column
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand on the opposite side from the external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # at least one equality is required; only EQ predicates become the join condition below
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's projections
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    # maps each key expression to the alias it is projected under in the subquery
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # append=False: the aggregated value replaces the existing projections
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped on are projected through the aggregate instead
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # `alias` is rebound here: from now on it is the joined table's value column
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # rewrite the expression that contained the subquery, depending on its kind
    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL test on the first projected key of the joined table
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # the comparison wrapping ALL (its parent) is replaced as a whole
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # any other usage: the subquery node itself is swapped for the joined column
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                # COUNT nodes become the literal 0, every other aggregate becomes NULL
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # take the predicate out of the subquery's WHERE, leaving TRUE in its place;
        # the detached predicate object is edited below and reused for the join
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            # point the key side of the predicate at the joined table's column
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # the key becomes the lambda variable `_x` inside the predicate, which is
            # then evaluated over the joined table's column with ARRAY_ANY
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    # join the grouped subquery, using the (edited) equality predicates as the ON condition
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
248
249
def _replace(expression, condition):
    """
    Replace `expression` in its tree with a condition built from `condition`.

    `condition` is passed through `exp.condition`, so it may be a SQL string or
    an expression. Returns whatever `expression.replace` returns.
    """
    replacement = exp.condition(condition)
    return expression.replace(replacement)
252
253
def _other_operand(expression):
    """
    Return the operand that a subquery-like expression is being compared against.

    * `IN`: the left-hand side (`expression.this`).
    * `ANY` / `ALL`: the result for the parent expression.
    * Any other binary expression: the right operand if the left one is a
      subquery, ANY, EXISTS or ALL; otherwise the left operand.
    * Anything else: None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if not isinstance(expression, exp.Binary):
        return None

    left = expression.left
    if isinstance(left, (exp.Subquery, exp.Any, exp.Exists, exp.All)):
        return expression.right
    return left
def unnest_subqueries(expression):
 7def unnest_subqueries(expression):
 8    """
 9    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
10
11    Convert scalar subqueries into cross joins.
12    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
13
14    Example:
15        >>> import sqlglot
16        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
17        >>> unnest_subqueries(expression).sql()
18        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
19
20    Args:
21        expression (sqlglot.Expression): expression to unnest
22    Returns:
23        sqlglot.Expression: unnested expression
24    """
25    next_alias_name = name_sequence("_u_")
26
27    for scope in traverse_scope(expression):
28        select = scope.expression
29        parent = select.parent_select
30        if not parent:
31            continue
32        if scope.external_columns:
33            decorrelate(select, parent, scope.external_columns, next_alias_name)
34        elif scope.scope_type == ScopeType.SUBQUERY:
35            unnest(select, parent, next_alias_name)
36
37    return expression

Rewrite sqlglot AST to convert some predicates with subqueries into joins.

Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into grouped left joins, so that the resulting join is not many-to-many.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
  • expression (sqlglot.Expression): expression to unnest
Returns:

sqlglot.Expression: unnested expression

def unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Rewrite an uncorrelated subquery used in a condition as a join on its parent.

    Nothing is rewritten when the subquery has more than one projection, or when
    no enclosing `exp.Condition` belonging to `parent_select` is found.

    * If the enclosing condition is neither `IN` nor `ANY`, the subquery is
      treated as a scalar: it is replaced by a column reference (wrapped in
      `MAX` in the grouped / HAVING cases below) and CROSS JOINed.
    * Otherwise (`IN` / `ANY`), unless the subquery contains LIMIT or OFFSET,
      the predicate is replaced by a NOT ... IS NULL test and the subquery,
      grouped by its projected value, is LEFT JOINed.

    Args:
        select: the subquery's select expression.
        parent_select: the select expression that contains the subquery.
        next_alias_name: callable returning a new alias name on each call.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # The alias is drawn before the guard below, so a name is consumed from the
    # sequence even when this call ends up not rewriting anything.
    alias = next_alias_name()

    if not predicate or predicate.parent_select is not parent_select:
        return

    is_membership_test = isinstance(predicate, (exp.In, exp.Any))

    # This subquery returns a scalar and can just be converted to a cross join
    if not is_membership_test:
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None
        clause_in_parent = clause_parent_select is parent_select

        if isinstance(clause, exp.Having) and clause_in_parent:
            wrap_in_max = True
        elif not clause or not clause_in_parent:
            # Not inside a HAVING/WHERE/JOIN of the parent: aggregate only if
            # the parent select is itself grouped or projects an aggregate.
            wrap_in_max = bool(
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        else:
            wrap_in_max = False

        if wrap_in_max:
            column = exp.Max(this=column)

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For ANY, the predicate to rewrite is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or predicate.parent_select is not parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    grouped_subquery = select.group_by(value.this, copy=False)
    parent_select.join(
        grouped_subquery,
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped LEFT JOIN on its parent select.

    The binary predicates in the subquery's WHERE clause that involve external
    columns are replaced with TRUE, the subquery is grouped by the collected
    keys and LEFT JOINed to `parent_select` on the equality predicates, and the
    expression that contained the subquery is rewritten to reference the joined
    table's columns. All builder calls use `copy=False`, so `select` and
    `parent_select` are modified in place.

    The function returns without rewriting anything when:
      * the subquery has no WHERE clause, its WHERE contains an OR, or the
        subquery contains a LIMIT or OFFSET;
      * an external column, or the predicate enclosing it, is not located
        under this subquery's own WHERE clause;
      * the predicate enclosing an external column is not an `exp.Binary`;
      * none of the collected predicates is an equality (`exp.EQ`).

    Note that the table alias is drawn from `next_alias_name` before the
    per-column checks, so a name is consumed even if one of those checks
    bails out.

    Args:
        select: the correlated subquery's select expression.
        parent_select: the select expression that contains the subquery.
        external_columns: the columns inside `select` that refer to an outer
            scope (the caller passes `scope.external_columns`).
        next_alias_name: callable returning a new alias name on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # list of (key, external column, predicate) triples, one per external column
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand on the opposite side from the external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # at least one equality is required; only EQ predicates become the join condition below
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's projections
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    # maps each key expression to the alias it is projected under in the subquery
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # append=False: the aggregated value replaces the existing projections
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped on are projected through the aggregate instead
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # `alias` is rebound here: from now on it is the joined table's value column
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # rewrite the expression that contained the subquery, depending on its kind
    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL test on the first projected key of the joined table
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # the comparison wrapping ALL (its parent) is replaced as a whole
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # any other usage: the subquery node itself is swapped for the joined column
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                # COUNT nodes become the literal 0, every other aggregate becomes NULL
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # take the predicate out of the subquery's WHERE, leaving TRUE in its place;
        # the detached predicate object is edited below and reused for the join
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            # point the key side of the predicate at the joined table's column
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # the key becomes the lambda variable `_x` inside the predicate, which is
            # then evaluated over the joined table's column with ARRAY_ANY
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    # join the grouped subquery, using the (edited) equality predicates as the ON condition
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )