1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
|
import itertools
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Every unnested subquery is turned into a GROUP BY subquery so that the
    resulting LEFT JOIN is not many-to-many. Subqueries that carry LIMIT or
    OFFSET are never unnested, and uncorrelated subqueries are only unnested
    when they are consumed by an IN or an = ANY predicate.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
 AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'

    Args:
        expression (sqlglot.Expression): expression to unnest; it is rewritten in place.

    Returns:
        sqlglot.Expression: the same expression object, after unnesting.
    """
    # one shared counter so every generated "_u_<n>" alias in the tree is distinct
    alias_counter = itertools.count()

    for scope in traverse_scope(expression):
        inner = scope.expression
        outer = inner.parent_select

        if not scope.external_columns:
            # uncorrelated: only IN / = ANY subqueries are rewritten
            unnest(inner, outer, alias_counter)
            continue

        decorrelate(inner, outer, scope.external_columns, alias_counter)

    return expression
def unnest(select, parent_select, sequence):
    """
    Rewrite an uncorrelated ``IN`` / ``= ANY`` subquery into a grouped LEFT JOIN.

    ``x IN (SELECT v ...)`` and ``x = ANY (SELECT v ...)`` become a LEFT JOIN of
    ``(SELECT v ... GROUP BY v)`` under a fresh ``_u_<n>`` alias, joined on
    ``x = "_u_<n>"."<alias of v>"``. The predicate itself is replaced by
    ``NOT <joined column> IS NULL``.

    The tree is left unchanged in these cases:
    * the subquery is not inside an IN / ANY predicate belonging to ``parent_select``;
    * an ANY has no enclosing equality in ``parent_select``;
    * the subquery projects more than one column;
    * the subquery contains LIMIT / OFFSET.

    Args:
        select (sqlglot.Expression): the subquery; it is grouped and moved into the join.
        parent_select (sqlglot.Expression): the query the subquery is nested in.
        sequence: iterator consumed by ``_alias`` to name the join.
    """
    consumer = select.find_ancestor(exp.In, exp.Any)
    if not consumer or consumer.parent_select is not parent_select:
        return

    if select.find(exp.Limit, exp.Offset) or len(select.selects) > 1:
        return

    if isinstance(consumer, exp.Any):
        # "x = ANY (...)" is rewritten through its enclosing equality
        consumer = consumer.find_ancestor(exp.EQ)
        if not consumer or consumer.parent_select is not parent_select:
            return

    projection = select.selects[0]
    join_alias = _alias(sequence)
    join_condition = exp.condition(
        f'{_other_operand(consumer)} = "{join_alias}"."{projection.alias}"'
    )

    # after the LEFT JOIN, a NULL joined column means the outer row had no match
    _replace(consumer, f"NOT {join_condition.right} IS NULL")

    grouped = select.group_by(projection.this, copy=False)
    parent_select.join(
        grouped,
        on=join_condition,
        join_type="LEFT",
        join_alias=join_alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, sequence):
    """
    Rewrite a correlated subquery into a grouped LEFT JOIN on ``parent_select``.

    The rewrite is only attempted when:
    * the subquery has a WHERE clause without OR, and no LIMIT / OFFSET;
    * every external column sits inside a binary predicate of that WHERE clause;
    * at least one of those predicates is an equality (it becomes the join key);
    * the subquery is consumed by a predicate that belongs to ``parent_select``.
    Otherwise the tree is left untouched.

    On success:
    * the correlated predicates inside the subquery are replaced with TRUE;
    * the subquery is grouped by its equality keys, plus its value when it is
      filtered on;
    * it is LEFT JOINed to ``parent_select`` under a generated ``_u_<n>`` alias;
    * the predicate that consumed the subquery is rewritten against the joined
      columns.

    Args:
        select (sqlglot.Expression): the subquery; mutated in place.
        parent_select (sqlglot.Expression): the query the subquery is nested in;
            a join is added to it in place.
        external_columns: columns inside ``select`` that reference outer tables.
        sequence: iterator consumed by ``_alias`` to name the join and key aliases.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = _alias(sequence)
    keys = []

    # for all external columns in the where statement,
    # split out the relevant data to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if not isinstance(predicate, exp.Binary):
            return

        # the key is whichever side of the predicate does not contain the external column
        key = (
            predicate.right
            if any(node is column for node, *_ in predicate.left.walk())
            else predicate.left
        )
        keys.append((key, column, predicate))

    # an equality is required: equalities become the ON condition of the join
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # The predicate that consumes the subquery (EXISTS, IN, ANY, ALL, a comparison...)
    # must sit between the subquery and parent_select.
    # - When there is none (e.g. a correlated scalar subquery in the projection list),
    #   the rewrite below would call .replace on None after already mutating the tree.
    # - When it belongs to an enclosing query, the rewrite would replace a predicate
    #   of that other query.
    # Bail out before any mutation happens.
    parent_predicate = select.find_ancestor(exp.Predicate)

    if (
        not parent_predicate
        or parent_select is None
        or parent_predicate.parent_select is not parent_select
    ):
        return

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = _alias(sequence)
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    is_exists = isinstance(parent_predicate, exp.Exists)

    if is_exists:
        # EXISTS only checks whether rows exist, so the original projections are
        # dropped; the join keys are projected below. Clearing first also avoids
        # building an ARRAY_AGG projection that would be thrown away immediately.
        select.args["expressions"] = []
    elif not value.find(exp.AggFunc) and value.this not in group_by:
        # if the value of the subquery is not an agg or a key, we need to collect
        # it into an array so that it can be grouped
        select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)

    for key, key_alias in key_aliases.items():
        if key in group_by:
            # add all grouped keys to the projections of the subquery so they can
            # be used as join keys (the value is already projected unless EXISTS)
            if is_exists or key != value.this:
                select.select(f"{key} AS {key_alias}", copy=False)
        else:
            # keys of non-equality predicates may vary within a group: collect them
            select.select(f"ARRAY_AGG({key}) AS {key_alias}", copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if is_exists:
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
        else:
            parent_predicate = _replace(parent_predicate, "TRUE")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # subquery used directly as a scalar: substitute the joined value column
        select.parent.replace(alias)

    for key, column, predicate in keys:
        # exp.TRUE is a single module-level node; insert a fresh copy each time so
        # that the same node object never ends up at several places in the tree
        predicate.replace(exp.TRUE.copy())
        nested = exp.column(key_aliases[key], table_alias)

        if key in group_by:
            # the detached predicate now compares the outer column to the joined key
            key.replace(nested)
            parent_predicate = _replace(
                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
            )
            if not isinstance(predicate, exp.EQ):
                # Only equalities are used as join conditions below. A non-equality
                # filter on a grouped key was removed from the subquery above, so it
                # must be re-applied here or it would be silently lost.
                parent_predicate = _replace(
                    parent_predicate, f"({parent_predicate} AND {predicate})"
                )
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
def _alias(sequence):
    """Return the next generated alias name ``_u_<n>``, taking ``n`` from ``sequence``."""
    return "_u_{}".format(next(sequence))
def _replace(expression, condition):
    """
    Swap ``expression`` in its tree for the condition parsed from the SQL text ``condition``.

    Returns whatever ``expression.replace`` returns. Callers use it as the new node.
    """
    swap = expression.replace
    return swap(exp.condition(condition))
def _other_operand(expression):
    """
    Return the operand that a subquery is compared against in ``expression``.

    * ``x IN (subquery)``                    -> ``x``
    * ``x <op> ANY/ALL (subquery)``          -> ``x`` (read from the enclosing comparison)
    * ``x <op> (subquery)`` / ``(subquery) <op> x`` -> ``x``

    Returns None for anything else (e.g. EXISTS, or None itself).
    """
    if isinstance(expression, exp.In):
        return expression.this

    # ANY / ALL wrap only the subquery; the compared operand lives on their parent
    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        # Choose by which side holds the subquery, not by where the comparison
        # itself sits inside its own parent.
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.All, exp.Exists))
            else expression.left
        )

    return None
|