Edit on GitHub

sqlglot.optimizer.unnest_subqueries

  1import itertools
  2
  3from sqlglot import exp
  4from sqlglot.optimizer.scope import ScopeType, traverse_scope
  5
  6
  7def unnest_subqueries(expression):
  8    """
  9    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
 10
 11    Convert scalar subqueries into cross joins.
 12    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
 13
 14    Example:
 15        >>> import sqlglot
 16        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
 17        >>> unnest_subqueries(expression).sql()
 18        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
 19
 20    Args:
 21        expression (sqlglot.Expression): expression to unnest
 22    Returns:
 23        sqlglot.Expression: unnested expression
 24    """
 25    sequence = itertools.count()
 26
 27    for scope in traverse_scope(expression):
 28        select = scope.expression
 29        parent = select.parent_select
 30        if not parent:
 31            continue
 32        if scope.external_columns:
 33            decorrelate(select, parent, scope.external_columns, sequence)
 34        elif scope.scope_type == ScopeType.SUBQUERY:
 35            unnest(select, parent, sequence)
 36
 37    return expression
 38
 39
 40def unnest(select, parent_select, sequence):
 41    if len(select.selects) > 1:
 42        return
 43
 44    predicate = select.find_ancestor(exp.Condition)
 45    alias = _alias(sequence)
 46
 47    if not predicate or parent_select is not predicate.parent_select:
 48        return
 49
 50    # this subquery returns a scalar and can just be converted to a cross join
 51    if not isinstance(predicate, (exp.In, exp.Any)):
 52        having = predicate.find_ancestor(exp.Having)
 53        column = exp.column(select.selects[0].alias_or_name, alias)
 54        if having and having.parent_select is parent_select:
 55            column = exp.Max(this=column)
 56        _replace(select.parent, column)
 57
 58        parent_select.join(
 59            select,
 60            join_type="CROSS",
 61            join_alias=alias,
 62            copy=False,
 63        )
 64        return
 65
 66    if select.find(exp.Limit, exp.Offset):
 67        return
 68
 69    if isinstance(predicate, exp.Any):
 70        predicate = predicate.find_ancestor(exp.EQ)
 71
 72        if not predicate or parent_select is not predicate.parent_select:
 73            return
 74
 75    column = _other_operand(predicate)
 76    value = select.selects[0]
 77
 78    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
 79    _replace(predicate, f"NOT {on.right} IS NULL")
 80
 81    parent_select.join(
 82        select.group_by(value.this, copy=False),
 83        on=on,
 84        join_type="LEFT",
 85        join_alias=alias,
 86        copy=False,
 87    )
 88
 89
 90def decorrelate(select, parent_select, external_columns, sequence):
 91    where = select.args.get("where")
 92
 93    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
 94        return
 95
 96    table_alias = _alias(sequence)
 97    keys = []
 98
 99    # for all external columns in the where statement, find the relevant predicate
100    # keys to convert it into a join
101    for column in external_columns:
102        if column.find_ancestor(exp.Where) is not where:
103            return
104
105        predicate = column.find_ancestor(exp.Predicate)
106
107        if not predicate or predicate.find_ancestor(exp.Where) is not where:
108            return
109
110        if isinstance(predicate, exp.Binary):
111            key = (
112                predicate.right
113                if any(node is column for node, *_ in predicate.left.walk())
114                else predicate.left
115            )
116        else:
117            return
118
119        keys.append((key, column, predicate))
120
121    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
122        return
123
124    is_subquery_projection = any(
125        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
126    )
127
128    value = select.selects[0]
129    key_aliases = {}
130    group_by = []
131
132    for key, _, predicate in keys:
133        # if we filter on the value of the subquery, it needs to be unique
134        if key == value.this:
135            key_aliases[key] = value.alias
136            group_by.append(key)
137        else:
138            if key not in key_aliases:
139                key_aliases[key] = _alias(sequence)
140            # all predicates that are equalities must also be in the unique
141            # so that we don't do a many to many join
142            if isinstance(predicate, exp.EQ) and key not in group_by:
143                group_by.append(key)
144
145    parent_predicate = select.find_ancestor(exp.Predicate)
146
147    # if the value of the subquery is not an agg or a key, we need to collect it into an array
148    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
149    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
150    if not value.find(exp.AggFunc) and value.this not in group_by:
151        select.select(
152            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
153            append=False,
154            copy=False,
155        )
156
157    # exists queries should not have any selects as it only checks if there are any rows
158    # all selects will be added by the optimizer and only used for join keys
159    if isinstance(parent_predicate, exp.Exists):
160        select.args["expressions"] = []
161
162    for key, alias in key_aliases.items():
163        if key in group_by:
164            # add all keys to the projections of the subquery
165            # so that we can use it as a join key
166            if isinstance(parent_predicate, exp.Exists) or key != value.this:
167                select.select(f"{key} AS {alias}", copy=False)
168        else:
169            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
170
171    alias = exp.column(value.alias, table_alias)
172    other = _other_operand(parent_predicate)
173
174    if isinstance(parent_predicate, exp.Exists):
175        alias = exp.column(list(key_aliases.values())[0], table_alias)
176        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
177    elif isinstance(parent_predicate, exp.All):
178        parent_predicate = _replace(
179            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
180        )
181    elif isinstance(parent_predicate, exp.Any):
182        if value.this in group_by:
183            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
184        else:
185            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
186    elif isinstance(parent_predicate, exp.In):
187        if value.this in group_by:
188            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
189        else:
190            parent_predicate = _replace(
191                parent_predicate,
192                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
193            )
194    else:
195        if is_subquery_projection:
196            alias = exp.alias_(alias, select.parent.alias)
197
198        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
199        # by transforming all counts into 0 and using that as the coalesced value
200        if value.find(exp.Count):
201
202            def remove_aggs(node):
203                if isinstance(node, exp.Count):
204                    return exp.Literal.number(0)
205                elif isinstance(node, exp.AggFunc):
206                    return exp.null()
207                return node
208
209            alias = exp.Coalesce(
210                this=alias,
211                expressions=[value.this.transform(remove_aggs)],
212            )
213
214        select.parent.replace(alias)
215
216    for key, column, predicate in keys:
217        predicate.replace(exp.true())
218        nested = exp.column(key_aliases[key], table_alias)
219
220        if is_subquery_projection:
221            key.replace(nested)
222            continue
223
224        if key in group_by:
225            key.replace(nested)
226        elif isinstance(predicate, exp.EQ):
227            parent_predicate = _replace(
228                parent_predicate,
229                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
230            )
231        else:
232            key.replace(exp.to_identifier("_x"))
233            parent_predicate = _replace(
234                parent_predicate,
235                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
236            )
237
238    parent_select.join(
239        select.group_by(*group_by, copy=False),
240        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
241        join_type="LEFT",
242        join_alias=table_alias,
243        copy=False,
244    )
245
246
247def _alias(sequence):
248    return f"_u_{next(sequence)}"
249
250
def _replace(expression, condition):
    """Swap `expression` in its tree for `exp.condition(condition)` and return the result of `replace`."""
    replacement = exp.condition(condition)
    return expression.replace(replacement)
253
254
def _other_operand(expression):
    """
    Return the operand that a subquery-style predicate is compared against.

    * `exp.In`: the tested expression (`this`).
    * `exp.Any` / `exp.All`: whatever the parent expression yields.
    * `exp.Binary`: the right side when the left side is a subquery, ANY,
      EXISTS or ALL; otherwise the left side.
    * anything else: None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if not isinstance(expression, exp.Binary):
        return None

    if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All)):
        return expression.right
    return expression.left
def unnest_subqueries(expression):
 8def unnest_subqueries(expression):
 9    """
10    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
11
12    Convert scalar subqueries into cross joins.
13    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
14
15    Example:
16        >>> import sqlglot
17        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
18        >>> unnest_subqueries(expression).sql()
19        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
20
21    Args:
22        expression (sqlglot.Expression): expression to unnest
23    Returns:
24        sqlglot.Expression: unnested expression
25    """
26    sequence = itertools.count()
27
28    for scope in traverse_scope(expression):
29        select = scope.expression
30        parent = select.parent_select
31        if not parent:
32            continue
33        if scope.external_columns:
34            decorrelate(select, parent, scope.external_columns, sequence)
35        elif scope.scope_type == ScopeType.SUBQUERY:
36            unnest(select, parent, sequence)
37
38    return expression

Rewrite sqlglot AST to convert some predicates with subqueries into joins.

Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into left joins against a grouped version of the subquery, so that the join is not many-to-many.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
  • expression (sqlglot.Expression): expression to unnest
Returns:

sqlglot.Expression: unnested expression

def unnest(select, parent_select, sequence):
def unnest(select, parent_select, sequence):
    """
    Replace an uncorrelated subquery inside a condition with a join.

    Nothing is rewritten when the subquery projects more than one expression,
    when its nearest `exp.Condition` ancestor is missing or does not belong
    directly to `parent_select`, or (for IN / ANY) when the subquery contains
    a LIMIT or OFFSET.

    * Not inside IN / ANY: the subquery is swapped for a column reference and
      `select` is CROSS joined to `parent_select`.
    * Inside IN / ANY: `select` is grouped by its single projection and LEFT
      joined on equality with the other operand; the predicate itself becomes
      a `NOT <column> IS NULL` check.

    Args:
        select: select expression of the subquery; modified in place.
        parent_select: select expression that receives the join.
        sequence: iterator of integers used to build the `_u_<n>` join alias.
    """
    if len(select.selects) > 1:
        return

    condition = select.find_ancestor(exp.Condition)
    # The alias is drawn before the bail-out check below, so a sequence number
    # is consumed even when nothing ends up being rewritten.
    join_alias = _alias(sequence)

    if not (condition and condition.parent_select is parent_select):
        return

    if not isinstance(condition, (exp.In, exp.Any)):
        # Neither IN nor ANY: the subquery is treated as returning a scalar
        # and can simply be cross joined.
        having = condition.find_ancestor(exp.Having)
        scalar_ref = exp.column(select.selects[0].alias_or_name, join_alias)
        if having and having.parent_select is parent_select:
            # Inside the parent's HAVING clause the reference is wrapped in MAX
            # (presumably to keep it valid in an aggregated context).
            scalar_ref = exp.Max(this=scalar_ref)
        _replace(select.parent, scalar_ref)

        parent_select.join(select, join_type="CROSS", join_alias=join_alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(condition, exp.Any):
        # For `x = ANY (...)` the predicate to rewrite is the enclosing equality.
        condition = condition.find_ancestor(exp.EQ)
        if not (condition and condition.parent_select is parent_select):
            return

    operand = _other_operand(condition)
    projection = select.selects[0]

    join_on = exp.condition(f'{operand} = "{join_alias}"."{projection.alias}"')
    _replace(condition, f"NOT {join_on.right} IS NULL")

    # Grouping by the projected value keeps the left join from multiplying rows.
    grouped = select.group_by(projection.this, copy=False)
    parent_select.join(
        grouped,
        on=join_on,
        join_type="LEFT",
        join_alias=join_alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, sequence):
 91def decorrelate(select, parent_select, external_columns, sequence):
 92    where = select.args.get("where")
 93
 94    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
 95        return
 96
 97    table_alias = _alias(sequence)
 98    keys = []
 99
100    # for all external columns in the where statement, find the relevant predicate
101    # keys to convert it into a join
102    for column in external_columns:
103        if column.find_ancestor(exp.Where) is not where:
104            return
105
106        predicate = column.find_ancestor(exp.Predicate)
107
108        if not predicate or predicate.find_ancestor(exp.Where) is not where:
109            return
110
111        if isinstance(predicate, exp.Binary):
112            key = (
113                predicate.right
114                if any(node is column for node, *_ in predicate.left.walk())
115                else predicate.left
116            )
117        else:
118            return
119
120        keys.append((key, column, predicate))
121
122    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
123        return
124
125    is_subquery_projection = any(
126        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
127    )
128
129    value = select.selects[0]
130    key_aliases = {}
131    group_by = []
132
133    for key, _, predicate in keys:
134        # if we filter on the value of the subquery, it needs to be unique
135        if key == value.this:
136            key_aliases[key] = value.alias
137            group_by.append(key)
138        else:
139            if key not in key_aliases:
140                key_aliases[key] = _alias(sequence)
141            # all predicates that are equalities must also be in the unique
142            # so that we don't do a many to many join
143            if isinstance(predicate, exp.EQ) and key not in group_by:
144                group_by.append(key)
145
146    parent_predicate = select.find_ancestor(exp.Predicate)
147
148    # if the value of the subquery is not an agg or a key, we need to collect it into an array
149    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
150    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
151    if not value.find(exp.AggFunc) and value.this not in group_by:
152        select.select(
153            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
154            append=False,
155            copy=False,
156        )
157
158    # exists queries should not have any selects as it only checks if there are any rows
159    # all selects will be added by the optimizer and only used for join keys
160    if isinstance(parent_predicate, exp.Exists):
161        select.args["expressions"] = []
162
163    for key, alias in key_aliases.items():
164        if key in group_by:
165            # add all keys to the projections of the subquery
166            # so that we can use it as a join key
167            if isinstance(parent_predicate, exp.Exists) or key != value.this:
168                select.select(f"{key} AS {alias}", copy=False)
169        else:
170            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
171
172    alias = exp.column(value.alias, table_alias)
173    other = _other_operand(parent_predicate)
174
175    if isinstance(parent_predicate, exp.Exists):
176        alias = exp.column(list(key_aliases.values())[0], table_alias)
177        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
178    elif isinstance(parent_predicate, exp.All):
179        parent_predicate = _replace(
180            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
181        )
182    elif isinstance(parent_predicate, exp.Any):
183        if value.this in group_by:
184            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
185        else:
186            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
187    elif isinstance(parent_predicate, exp.In):
188        if value.this in group_by:
189            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
190        else:
191            parent_predicate = _replace(
192                parent_predicate,
193                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
194            )
195    else:
196        if is_subquery_projection:
197            alias = exp.alias_(alias, select.parent.alias)
198
199        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
200        # by transforming all counts into 0 and using that as the coalesced value
201        if value.find(exp.Count):
202
203            def remove_aggs(node):
204                if isinstance(node, exp.Count):
205                    return exp.Literal.number(0)
206                elif isinstance(node, exp.AggFunc):
207                    return exp.null()
208                return node
209
210            alias = exp.Coalesce(
211                this=alias,
212                expressions=[value.this.transform(remove_aggs)],
213            )
214
215        select.parent.replace(alias)
216
217    for key, column, predicate in keys:
218        predicate.replace(exp.true())
219        nested = exp.column(key_aliases[key], table_alias)
220
221        if is_subquery_projection:
222            key.replace(nested)
223            continue
224
225        if key in group_by:
226            key.replace(nested)
227        elif isinstance(predicate, exp.EQ):
228            parent_predicate = _replace(
229                parent_predicate,
230                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
231            )
232        else:
233            key.replace(exp.to_identifier("_x"))
234            parent_predicate = _replace(
235                parent_predicate,
236                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
237            )
238
239    parent_select.join(
240        select.group_by(*group_by, copy=False),
241        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
242        join_type="LEFT",
243        join_alias=table_alias,
244        copy=False,
245    )