1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
|
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.simplify import simplify
def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    The tree is rewritten in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    # NOTE(review): scopes are visited in the reverse of traverse_scope's
    # order, presumably so an enclosing query is handled before the queries
    # it selects from -- confirm against traverse_scope.
    for scope in reversed(traverse_scope(expression)):
        select = scope.expression
        where = select.args.get("where")
        if where:
            pushdown(where.this, scope.selected_sources)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        for join in select.args.get("joins") or []:
            name = join.this.alias_or_name

            # A join target with no entry in the selected sources has nowhere
            # to receive predicates; skip it rather than fail with KeyError
            # and abort the whole rewrite.
            if name not in scope.selected_sources:
                continue

            pushdown(join.args.get("on"), {name: scope.selected_sources[name]})

    return expression
def pushdown(condition, sources):
    """
    Simplify a condition, split it into predicates and push those down.

    Args:
        condition (sqlglot.Expression): the WHERE or JOIN ... ON condition; a
            falsy value means there is nothing to do
        sources (dict): mapping of source name to the entry consumed by
            `nodes_for_predicate`
    """
    if not condition:
        return

    # Swap the condition for its simplified form inside the tree and keep
    # working with whatever `replace` hands back.
    simplified = condition.replace(simplify(condition))

    # Anything that is not strictly in DNF is handled as CNF.
    is_cnf_like = normalized(simplified) or not normalized(simplified, dnf=True)

    # CNF splits on the top-level AND, DNF on the top-level OR; a condition
    # that is not such a connector is a single predicate on its own.
    connector = exp.And if is_cnf_like else exp.Or
    if isinstance(simplified, connector):
        predicates = list(simplified.flatten())
    else:
        predicates = [simplified]

    if is_cnf_like:
        pushdown_cnf(predicates, sources)
    else:
        pushdown_dnf(predicates, sources)
def pushdown_cnf(predicates, scope):
    """
    Push down predicates of a CNF like condition.

    Every predicate is an independent conjunct, so it can be moved: it is
    replaced with TRUE where it was found and attached to the node it targets.

    Args:
        predicates (list): conjuncts of the condition
        scope (dict): mapping of source name to the entry consumed by
            `nodes_for_predicate`
    """
    for conjunct in predicates:
        targets = nodes_for_predicate(conjunct, scope)

        for target in targets.values():
            if isinstance(target, exp.Join):
                # Detach from the original location first, then add the
                # predicate to the join's ON clause; one join is enough.
                conjunct.replace(exp.TRUE)
                target.on(conjunct, copy=False)
                break
            elif isinstance(target, exp.Select):
                # Inside the subquery the predicate has to refer to the
                # underlying expressions instead of the projected names.
                conjunct.replace(exp.TRUE)
                target.where(replace_aliases(target, conjunct), copy=False)
def pushdown_dnf(predicates, scope):
    """
    Push down predicates of a DNF condition.

    Only conditions on tables that appear in every block of the disjunction
    can be pushed down, and the original predicates must stay where they are.

    Args:
        predicates (list): blocks (disjuncts) of the condition
        scope (dict): mapping of source name to the entry consumed by
            `nodes_for_predicate`
    """
    # Collect the tables referenced by every block of the DNF. For
    # (a.x AND b.x) OR (a.y AND c.y) only table a qualifies.
    tables_per_block = [set(exp.column_table_names(block)) for block in predicates]

    pushdown_tables = set()
    for block_tables in tables_per_block:
        pushdown_tables |= block_tables.intersection(*tables_per_block)

    conditions = {}

    # For each qualifying table gather its conditions from every block and
    # OR the per-block results together:
    # (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
    for table in sorted(pushdown_tables):
        for block in predicates:
            nodes = nodes_for_predicate(block, scope)

            if table not in nodes:
                continue

            block_condition = None

            for column in block.find_all(exp.Column):
                if column.table != table:
                    continue

                enclosing = column.find_ancestor(exp.Condition)
                if block_condition:
                    block_condition = exp.and_(block_condition, enclosing)
                else:
                    block_condition = enclosing

            if block_condition:
                if table in conditions:
                    conditions[table] = exp.or_(conditions[table], block_condition)
                else:
                    conditions[table] = block_condition

        # `nodes` still holds the result computed for the last block above.
        for name, node in nodes.items():
            if name in conditions:
                combined = conditions[name]

                if isinstance(node, exp.Join):
                    node.on(combined, copy=False)
                elif isinstance(node, exp.Select):
                    node.where(replace_aliases(node, combined), copy=False)
def nodes_for_predicate(predicate, sources):
    """
    Find the nodes a predicate can be pushed down into.

    Args:
        predicate (sqlglot.Expression): predicate to look up targets for
        sources (dict): mapping of table name to a (node, source) pair

    Returns:
        dict: mapping of table name to the Join or Select node that can take
            the predicate. Empty when nothing qualifies, and also as soon as
            one referenced table sits behind a join that has a side.
    """
    targets = {}
    table_names = exp.column_table_names(predicate)

    # Is the closest Join/Where ancestor of the predicate a WHERE clause?
    closest_clause = predicate.find_ancestor(exp.Join, exp.Where)
    in_where = isinstance(closest_clause, exp.Where)

    for table_name in table_names:
        node, source = sources.get(table_name) or (None, None)

        # A predicate from a WHERE clause may be pushed down, so look up the
        # root join or from statement of the node.
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # The node can reference a CTE, in which case the predicate goes
        # into the expression behind that source.
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            node = source.expression

        if isinstance(node, exp.Join):
            # A join with a side rules out pushdown for the whole predicate.
            if node.side:
                return {}
            targets[table_name] = node
        elif (
            isinstance(node, exp.Select)
            and len(table_names) == 1
            and not node.args.get("group")
        ):
            # Only single table predicates go into a select, and only when
            # that select has no GROUP BY.
            targets[table_name] = node

    return targets
def replace_aliases(source, predicate):
    """
    Rewrite a predicate so it can be evaluated inside `source`.

    Columns whose name matches a projection of `source` are replaced with the
    expression behind that projection.

    Args:
        source (sqlglot.Expression): select the predicate is pushed into
        predicate (sqlglot.Expression): predicate to rewrite
    Returns:
        sqlglot.Expression: the result of transforming `predicate`
    """
    aliases = {}

    for select in source.selects:
        if isinstance(select, exp.Alias):
            aliases[select.alias] = select.this
        else:
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            # Hand out a copy: the projection node still belongs to `source`,
            # and the same node object must not end up in two places of the
            # tree (or several times in one predicate).
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)
|