Edit on GitHub

sqlglot.optimizer.unnest_subqueries

  1from sqlglot import exp
  2from sqlglot.helper import name_sequence
  3from sqlglot.optimizer.scope import ScopeType, traverse_scope
  4
  5
  6def unnest_subqueries(expression):
  7    """
  8    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
  9
 10    Convert scalar subqueries into cross joins.
 11    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
 16        >>> unnest_subqueries(expression).sql()
 17        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to unnest
 21    Returns:
 22        sqlglot.Expression: unnested expression
 23    """
 24    next_alias_name = name_sequence("_u_")
 25
 26    for scope in traverse_scope(expression):
 27        select = scope.expression
 28        parent = select.parent_select
 29        if not parent:
 30            continue
 31        if scope.external_columns:
 32            decorrelate(select, parent, scope.external_columns, next_alias_name)
 33        elif scope.scope_type == ScopeType.SUBQUERY:
 34            unnest(select, parent, next_alias_name)
 35
 36    return expression
 37
 38
 39def unnest(select, parent_select, next_alias_name):
 40    if len(select.selects) > 1:
 41        return
 42
 43    predicate = select.find_ancestor(exp.Condition)
 44    alias = next_alias_name()
 45
 46    if not predicate or parent_select is not predicate.parent_select:
 47        return
 48
 49    # this subquery returns a scalar and can just be converted to a cross join
 50    if not isinstance(predicate, (exp.In, exp.Any)):
 51        having = predicate.find_ancestor(exp.Having)
 52        column = exp.column(select.selects[0].alias_or_name, alias)
 53        if having and having.parent_select is parent_select:
 54            column = exp.Max(this=column)
 55        _replace(select.parent, column)
 56
 57        parent_select.join(
 58            select,
 59            join_type="CROSS",
 60            join_alias=alias,
 61            copy=False,
 62        )
 63        return
 64
 65    if select.find(exp.Limit, exp.Offset):
 66        return
 67
 68    if isinstance(predicate, exp.Any):
 69        predicate = predicate.find_ancestor(exp.EQ)
 70
 71        if not predicate or parent_select is not predicate.parent_select:
 72            return
 73
 74    column = _other_operand(predicate)
 75    value = select.selects[0]
 76
 77    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
 78    _replace(predicate, f"NOT {on.right} IS NULL")
 79
 80    parent_select.join(
 81        select.group_by(value.this, copy=False),
 82        on=on,
 83        join_type="LEFT",
 84        join_alias=alias,
 85        copy=False,
 86    )
 87
 88
 89def decorrelate(select, parent_select, external_columns, next_alias_name):
 90    where = select.args.get("where")
 91
 92    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
 93        return
 94
 95    table_alias = next_alias_name()
 96    keys = []
 97
 98    # for all external columns in the where statement, find the relevant predicate
 99    # keys to convert it into a join
100    for column in external_columns:
101        if column.find_ancestor(exp.Where) is not where:
102            return
103
104        predicate = column.find_ancestor(exp.Predicate)
105
106        if not predicate or predicate.find_ancestor(exp.Where) is not where:
107            return
108
109        if isinstance(predicate, exp.Binary):
110            key = (
111                predicate.right
112                if any(node is column for node, *_ in predicate.left.walk())
113                else predicate.left
114            )
115        else:
116            return
117
118        keys.append((key, column, predicate))
119
120    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
121        return
122
123    is_subquery_projection = any(
124        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
125    )
126
127    value = select.selects[0]
128    key_aliases = {}
129    group_by = []
130
131    for key, _, predicate in keys:
132        # if we filter on the value of the subquery, it needs to be unique
133        if key == value.this:
134            key_aliases[key] = value.alias
135            group_by.append(key)
136        else:
137            if key not in key_aliases:
138                key_aliases[key] = next_alias_name()
139            # all predicates that are equalities must also be in the unique
140            # so that we don't do a many to many join
141            if isinstance(predicate, exp.EQ) and key not in group_by:
142                group_by.append(key)
143
144    parent_predicate = select.find_ancestor(exp.Predicate)
145
146    # if the value of the subquery is not an agg or a key, we need to collect it into an array
147    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
148    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
149    if not value.find(exp.AggFunc) and value.this not in group_by:
150        select.select(
151            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
152            append=False,
153            copy=False,
154        )
155
156    # exists queries should not have any selects as it only checks if there are any rows
157    # all selects will be added by the optimizer and only used for join keys
158    if isinstance(parent_predicate, exp.Exists):
159        select.args["expressions"] = []
160
161    for key, alias in key_aliases.items():
162        if key in group_by:
163            # add all keys to the projections of the subquery
164            # so that we can use it as a join key
165            if isinstance(parent_predicate, exp.Exists) or key != value.this:
166                select.select(f"{key} AS {alias}", copy=False)
167        else:
168            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
169
170    alias = exp.column(value.alias, table_alias)
171    other = _other_operand(parent_predicate)
172
173    if isinstance(parent_predicate, exp.Exists):
174        alias = exp.column(list(key_aliases.values())[0], table_alias)
175        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
176    elif isinstance(parent_predicate, exp.All):
177        parent_predicate = _replace(
178            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
179        )
180    elif isinstance(parent_predicate, exp.Any):
181        if value.this in group_by:
182            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
183        else:
184            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
185    elif isinstance(parent_predicate, exp.In):
186        if value.this in group_by:
187            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
188        else:
189            parent_predicate = _replace(
190                parent_predicate,
191                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
192            )
193    else:
194        if is_subquery_projection:
195            alias = exp.alias_(alias, select.parent.alias)
196
197        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
198        # by transforming all counts into 0 and using that as the coalesced value
199        if value.find(exp.Count):
200
201            def remove_aggs(node):
202                if isinstance(node, exp.Count):
203                    return exp.Literal.number(0)
204                elif isinstance(node, exp.AggFunc):
205                    return exp.null()
206                return node
207
208            alias = exp.Coalesce(
209                this=alias,
210                expressions=[value.this.transform(remove_aggs)],
211            )
212
213        select.parent.replace(alias)
214
215    for key, column, predicate in keys:
216        predicate.replace(exp.true())
217        nested = exp.column(key_aliases[key], table_alias)
218
219        if is_subquery_projection:
220            key.replace(nested)
221            continue
222
223        if key in group_by:
224            key.replace(nested)
225        elif isinstance(predicate, exp.EQ):
226            parent_predicate = _replace(
227                parent_predicate,
228                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
229            )
230        else:
231            key.replace(exp.to_identifier("_x"))
232            parent_predicate = _replace(
233                parent_predicate,
234                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
235            )
236
237    parent_select.join(
238        select.group_by(*group_by, copy=False),
239        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
240        join_type="LEFT",
241        join_alias=table_alias,
242        copy=False,
243    )
244
245
def _replace(expression, condition):
    """
    Substitute ``condition`` for ``expression`` in its syntax tree.

    ``condition`` may be SQL text or an already-built expression; it is normalized with
    ``exp.condition`` first. The value of ``expression.replace`` is returned, and
    callers in this module use it as the node now occupying ``expression``'s position.
    """
    replacement = exp.condition(condition)
    return expression.replace(replacement)
248
249
def _other_operand(expression):
    """
    Return the operand that a subquery predicate compares against.

    * IN: its left-hand side (``this``).
    * ANY / ALL: resolved on the enclosing comparison.
    * Binary comparison: the side that is not a subquery, ANY, EXISTS or ALL.
    * Anything else: None.
    """
    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.In):
        return expression.this

    if not isinstance(expression, exp.Binary):
        return None

    lhs = expression.left
    lhs_is_subquery_side = isinstance(lhs, (exp.Subquery, exp.Any, exp.Exists, exp.All))
    return expression.right if lhs_is_subquery_side else lhs
def unnest_subqueries(expression):
 7def unnest_subqueries(expression):
 8    """
 9    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
10
11    Convert scalar subqueries into cross joins.
12    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
13
14    Example:
15        >>> import sqlglot
16        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
17        >>> unnest_subqueries(expression).sql()
18        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
19
20    Args:
21        expression (sqlglot.Expression): expression to unnest
22    Returns:
23        sqlglot.Expression: unnested expression
24    """
25    next_alias_name = name_sequence("_u_")
26
27    for scope in traverse_scope(expression):
28        select = scope.expression
29        parent = select.parent_select
30        if not parent:
31            continue
32        if scope.external_columns:
33            decorrelate(select, parent, scope.external_columns, next_alias_name)
34        elif scope.scope_type == ScopeType.SUBQUERY:
35            unnest(select, parent, next_alias_name)
36
37    return expression

Rewrite sqlglot AST to convert some predicates with subqueries into joins.

Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into grouped subqueries so that the resulting left join is not many-to-many.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
  • expression (sqlglot.Expression): expression to unnest
Returns:

sqlglot.Expression: unnested expression

def unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Rewrite an uncorrelated subquery used in a predicate of ``parent_select`` as a join.

    * A subquery used as a scalar (any predicate other than IN / ANY) becomes a CROSS JOIN,
      and its occurrence is replaced by a reference to its single projection. That
      reference is wrapped in MAX when the predicate sits in the parent's HAVING clause.
    * A subquery used with IN or ``= ANY`` becomes a LEFT JOIN against the subquery
      grouped by its projection. The predicate becomes a NOT NULL check on the joined
      column.

    The subquery is left untouched when:
    * it projects more than one column;
    * it is not inside a condition that belongs to ``parent_select``;
    * for IN / ANY, it has a LIMIT or OFFSET;
    * for IN / ANY, it is not a plain SELECT (e.g. a UNION);
    * an ANY has no enclosing equality in ``parent_select``.

    Args:
        select: the subquery expression (the scope's expression); modified in place.
        parent_select: the SELECT that directly contains the predicate.
        next_alias_name: callable producing fresh, unique alias names.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE: the alias is reserved before the early exits below, so a skipped
    # subquery still consumes a name from the sequence.
    alias = next_alias_name()

    if not predicate or parent_select is not predicate.parent_select:
        return

    # this subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        having = predicate.find_ancestor(exp.Having)
        column = exp.column(select.selects[0].alias_or_name, alias)
        if having and having.parent_select is parent_select:
            # HAVING is evaluated after grouping, so the joined value must be aggregated.
            column = exp.Max(this=column)
        _replace(select.parent, column)

        parent_select.join(
            select,
            join_type="CROSS",
            join_alias=alias,
            copy=False,
        )
        return

    if select.find(exp.Limit, exp.Offset):
        return

    # The rewrite below groups the subquery. Only a plain SELECT offers ``group_by``.
    # A set operation such as UNION would raise AttributeError, so it is left as is.
    if not isinstance(select, exp.Select):
        return

    if isinstance(predicate, exp.Any):
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    # Build the join condition and the replacement predicate as expression trees instead
    # of formatting SQL text and re-parsing it. ``alias_or_name`` matches the scalar
    # branch above and also works when the projection has no explicit alias.
    join_key = exp.column(value.alias_or_name, alias)
    on = exp.EQ(this=column.copy(), expression=join_key)
    _replace(predicate, exp.Not(this=exp.Is(this=join_key.copy(), expression=exp.null())))

    # Grouping by the projected value makes it unique, so the LEFT JOIN cannot
    # multiply rows of the parent query.
    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, next_alias_name):
 90def decorrelate(select, parent_select, external_columns, next_alias_name):
 91    where = select.args.get("where")
 92
 93    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
 94        return
 95
 96    table_alias = next_alias_name()
 97    keys = []
 98
 99    # for all external columns in the where statement, find the relevant predicate
100    # keys to convert it into a join
101    for column in external_columns:
102        if column.find_ancestor(exp.Where) is not where:
103            return
104
105        predicate = column.find_ancestor(exp.Predicate)
106
107        if not predicate or predicate.find_ancestor(exp.Where) is not where:
108            return
109
110        if isinstance(predicate, exp.Binary):
111            key = (
112                predicate.right
113                if any(node is column for node, *_ in predicate.left.walk())
114                else predicate.left
115            )
116        else:
117            return
118
119        keys.append((key, column, predicate))
120
121    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
122        return
123
124    is_subquery_projection = any(
125        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
126    )
127
128    value = select.selects[0]
129    key_aliases = {}
130    group_by = []
131
132    for key, _, predicate in keys:
133        # if we filter on the value of the subquery, it needs to be unique
134        if key == value.this:
135            key_aliases[key] = value.alias
136            group_by.append(key)
137        else:
138            if key not in key_aliases:
139                key_aliases[key] = next_alias_name()
140            # all predicates that are equalities must also be in the unique
141            # so that we don't do a many to many join
142            if isinstance(predicate, exp.EQ) and key not in group_by:
143                group_by.append(key)
144
145    parent_predicate = select.find_ancestor(exp.Predicate)
146
147    # if the value of the subquery is not an agg or a key, we need to collect it into an array
148    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
149    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
150    if not value.find(exp.AggFunc) and value.this not in group_by:
151        select.select(
152            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
153            append=False,
154            copy=False,
155        )
156
157    # exists queries should not have any selects as it only checks if there are any rows
158    # all selects will be added by the optimizer and only used for join keys
159    if isinstance(parent_predicate, exp.Exists):
160        select.args["expressions"] = []
161
162    for key, alias in key_aliases.items():
163        if key in group_by:
164            # add all keys to the projections of the subquery
165            # so that we can use it as a join key
166            if isinstance(parent_predicate, exp.Exists) or key != value.this:
167                select.select(f"{key} AS {alias}", copy=False)
168        else:
169            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
170
171    alias = exp.column(value.alias, table_alias)
172    other = _other_operand(parent_predicate)
173
174    if isinstance(parent_predicate, exp.Exists):
175        alias = exp.column(list(key_aliases.values())[0], table_alias)
176        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
177    elif isinstance(parent_predicate, exp.All):
178        parent_predicate = _replace(
179            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
180        )
181    elif isinstance(parent_predicate, exp.Any):
182        if value.this in group_by:
183            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
184        else:
185            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
186    elif isinstance(parent_predicate, exp.In):
187        if value.this in group_by:
188            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
189        else:
190            parent_predicate = _replace(
191                parent_predicate,
192                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
193            )
194    else:
195        if is_subquery_projection:
196            alias = exp.alias_(alias, select.parent.alias)
197
198        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
199        # by transforming all counts into 0 and using that as the coalesced value
200        if value.find(exp.Count):
201
202            def remove_aggs(node):
203                if isinstance(node, exp.Count):
204                    return exp.Literal.number(0)
205                elif isinstance(node, exp.AggFunc):
206                    return exp.null()
207                return node
208
209            alias = exp.Coalesce(
210                this=alias,
211                expressions=[value.this.transform(remove_aggs)],
212            )
213
214        select.parent.replace(alias)
215
216    for key, column, predicate in keys:
217        predicate.replace(exp.true())
218        nested = exp.column(key_aliases[key], table_alias)
219
220        if is_subquery_projection:
221            key.replace(nested)
222            continue
223
224        if key in group_by:
225            key.replace(nested)
226        elif isinstance(predicate, exp.EQ):
227            parent_predicate = _replace(
228                parent_predicate,
229                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
230            )
231        else:
232            key.replace(exp.to_identifier("_x"))
233            parent_predicate = _replace(
234                parent_predicate,
235                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
236            )
237
238    parent_select.join(
239        select.group_by(*group_by, copy=False),
240        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
241        join_type="LEFT",
242        join_alias=table_alias,
243        copy=False,
244    )