summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/unnest_subqueries.py
blob: 55c81c5664ec6f12ce20f28645d3da3695fc66d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import itertools

from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite predicates that use subqueries into joins on the enclosing query.

    Every scope of ``expression`` is visited. Scopes that reference columns of an
    outer query are handed to ``decorrelate``; all others go to ``unnest``, which
    only rewrites ``IN`` and ``= ANY`` predicates. The joined subquery is grouped so
    the LEFT JOIN cannot fan out into a many-to-many join, and subqueries with LIMIT
    or OFFSET are never unnested.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
 AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'

    Args:
        expression (sqlglot.Expression): expression to unnest; rewritten in place.
    Returns:
        sqlglot.Expression: the same ``expression`` object, after unnesting.
    """
    # one counter for the whole tree so every generated alias (_u_0, _u_1, ...) is unique
    alias_counter = itertools.count()

    for scope in traverse_scope(expression):
        subquery = scope.expression
        outer_select = subquery.parent_select
        outer_columns = scope.external_columns

        if not outer_columns:
            unnest(subquery, outer_select, alias_counter)
            continue

        decorrelate(subquery, outer_select, outer_columns, alias_counter)

    return expression


def unnest(select, parent_select, sequence):
    """
    Rewrite an uncorrelated ``IN`` / ``= ANY`` subquery into a LEFT JOIN.

    The subquery is grouped by its single projection and joined onto
    ``parent_select`` under a freshly generated alias. The consuming predicate
    is replaced by a ``NOT <joined column> IS NULL`` check. Nothing is changed
    in any of these cases:

    - the subquery is not under an IN / ANY that belongs to ``parent_select``;
    - an ANY is not inside an equality belonging to ``parent_select``;
    - the subquery projects more than one column;
    - the subquery uses LIMIT / OFFSET.

    Args:
        select (sqlglot.exp.Select): the uncorrelated subquery; mutated in place.
        parent_select (sqlglot.exp.Select): the query the subquery is joined into.
        sequence (itertools.count): shared counter used to generate unique aliases.
    """

    def belongs_to_parent(node):
        # the predicate must exist and its own enclosing select must be parent_select
        return node and node.parent_select is parent_select

    predicate = select.find_ancestor(exp.In, exp.Any)
    if not belongs_to_parent(predicate):
        return

    # multi-column projections and LIMIT/OFFSET subqueries are left as they are
    if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # for `x = ANY (subquery)` the whole enclosing equality is what gets rewritten
        predicate = predicate.find_ancestor(exp.EQ)
        if not belongs_to_parent(predicate):
            return

    lhs = _other_operand(predicate)
    projection = select.selects[0]
    join_alias = _alias(sequence)

    join_condition = exp.condition(f'{lhs} = "{join_alias}"."{projection.alias}"')
    # after the LEFT JOIN, a non-NULL joined value indicates a matching row was found
    _replace(predicate, f"NOT {join_condition.right} IS NULL")

    # group by the projected value so each outer row matches at most one joined row
    grouped = select.group_by(projection.this, copy=False)
    parent_select.join(
        grouped,
        on=join_condition,
        join_type="LEFT",
        join_alias=join_alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, sequence):
    """
    Rewrite a correlated subquery ``select`` into a LEFT JOIN on ``parent_select``.

    The correlation predicates in the subquery's WHERE clause are turned into join
    keys. Each key is projected from the subquery under a generated alias. Keys used
    in equalities, and the subquery value itself when it is a key, are also added to
    the GROUP BY; other keys and non-grouped values are collected with ARRAY_AGG. The
    predicate in the parent query that consumed the subquery is then rewritten to
    refer to the joined columns (using ARRAY_ANY / ARRAY_ALL / ARRAY_CONTAINS where
    the joined column is an array).

    The function returns without rewriting the tree in any of these cases:

    - the subquery has no WHERE, or its WHERE contains an OR;
    - the subquery has LIMIT or OFFSET;
    - an external column appears outside the WHERE, or outside a binary predicate
      in that WHERE;
    - none of the correlation predicates is an equality.

    Note that a table alias is drawn from ``sequence`` before most of these checks,
    so the counter can advance even when nothing is rewritten.

    Args:
        select (sqlglot.exp.Select): the correlated subquery; mutated in place.
        parent_select (sqlglot.exp.Select): the select the subquery is joined into.
        external_columns: columns inside ``select`` that reference an outer scope.
        sequence (itertools.count): shared counter used to generate unique aliases.
    """
    where = select.args.get("where")

    # only a WHERE without OR, in a subquery without LIMIT/OFFSET, is handled
    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # allocated before the per-column checks below, which may still bail out
    table_alias = _alias(sequence)
    # (key, external column, predicate) triples; key is the predicate operand
    # that does not contain the external column
    keys = []

    # for all external columns in the where statement,
    # split out the relevant data to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is whichever side of the binary predicate does NOT
            # contain this external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # the equality predicates become the ON clause of the join below, so at
    # least one is required
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # the subquery's first projection is treated as its value
    value = select.selects[0]
    # key expression -> alias it is projected under in the subquery
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = _alias(sequence)
            # all keys used in equality predicates must also be grouped on
            # (made unique) so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # the predicate in the outer query that consumes this subquery
    # (e.g. EXISTS, ALL, ANY, IN, or a comparison for a scalar subquery).
    # NOTE(review): if the subquery has no Predicate ancestor this is None and the
    # _replace(parent_predicate, ...) calls in the key loop below would fail —
    # confirm such subqueries never reach this point.
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped on are collected into arrays
            select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)

    # the subquery value as seen through the join alias
    alias = exp.column(value.alias, table_alias)
    # operand compared against the subquery (only used by the ALL/ANY/IN branches)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
        else:
            parent_predicate = _replace(parent_predicate, "TRUE")
    elif isinstance(parent_predicate, exp.All):
        # NOTE(review): the lambda always compares with `=` whatever operator
        # surrounds ALL, and `other` relies on _other_operand resolving an ALL node
        # to the comparison's other operand — confirm both are intended.
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        # NOTE(review): as for ALL, the rewrite assumes an equality comparison
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # e.g. a scalar subquery compared with a value: the subquery's parent
        # node is replaced by the joined column
        select.parent.replace(alias)

    for key, column, predicate in keys:
        # drop the correlation predicate from the subquery's WHERE; the detached
        # predicate object is reused below (EQs become the join's ON clause)
        predicate.replace(exp.TRUE)
        nested = exp.column(key_aliases[key], table_alias)

        if key in group_by:
            # point the inner side of the predicate at the joined column
            key.replace(nested)
            parent_predicate = _replace(
                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
            )
        elif isinstance(predicate, exp.EQ):
            # NOTE(review): every key of an EQ predicate was added to group_by in the
            # loop above, so this branch appears unreachable — confirm.
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # non-equality predicate on an aggregated key: evaluate it per array
            # element by binding the key to the lambda variable "_x"
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    # LEFT JOIN the grouped subquery using the (rewritten) equality predicates as ON
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _alias(sequence):
    """Return the next generated join alias (``_u_0``, ``_u_1``, ...) drawn from ``sequence``."""
    index = next(sequence)
    return "_u_{}".format(index)


def _replace(expression, condition):
    """
    Parse ``condition`` as a SQL condition and put it in place of ``expression``.

    Returns whatever ``expression.replace`` returns; callers in this module keep
    using that result as the node now occupying the tree position.
    """
    # bind the method first so evaluation order matches a direct call
    swap_in = expression.replace
    return swap_in(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand that a subquery is being compared against.

    Args:
        expression (sqlglot.Expression): the predicate consuming a subquery, e.g.
            ``x IN (SELECT ...)``, an ANY/ALL node such as in
            ``x = ANY (SELECT ...)`` (or the comparison enclosing it), or a
            comparison such as ``(SELECT ...) = 1``.

    Returns:
        sqlglot.Expression: the non-subquery operand, or None when ``expression``
        is not a predicate this helper understands (e.g. EXISTS).
    """
    if isinstance(expression, exp.In):
        return expression.this

    # ANY/ALL only wrap the subquery; the other operand lives on the
    # comparison that contains them
    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        # choose by which side holds the subquery, not by where this predicate
        # sits inside its own parent (that position says nothing about operands)
        if isinstance(expression.left, (exp.Subquery, exp.Any, exp.All, exp.Exists)):
            return expression.right
        return expression.left

    return None