summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/pushdown_predicates.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/pushdown_predicates.py')
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py38
1 files changed, 25 insertions, 13 deletions
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index a070d70..9c8d71d 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -1,3 +1,5 @@
+from collections import defaultdict
+
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import traverse_scope
@@ -20,22 +22,30 @@ def pushdown_predicates(expression):
Returns:
sqlglot.Expression: optimized expression
"""
- for scope in reversed(traverse_scope(expression)):
+ scope_ref_count = defaultdict(lambda: 0)
+ scopes = traverse_scope(expression)
+ scopes.reverse()
+
+ for scope in scopes:
+ for _, source in scope.selected_sources.values():
+ scope_ref_count[id(source)] += 1
+
+ for scope in scopes:
select = scope.expression
where = select.args.get("where")
if where:
- pushdown(where.this, scope.selected_sources)
+ pushdown(where.this, scope.selected_sources, scope_ref_count)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
for join in select.args.get("joins") or []:
name = join.this.alias_or_name
- pushdown(join.args.get("on"), {name: scope.selected_sources[name]})
+ pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
return expression
-def pushdown(condition, sources):
+def pushdown(condition, sources, scope_ref_count):
if not condition:
return
@@ -45,17 +55,17 @@ def pushdown(condition, sources):
predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
if cnf_like:
- pushdown_cnf(predicates, sources)
+ pushdown_cnf(predicates, sources, scope_ref_count)
else:
- pushdown_dnf(predicates, sources)
+ pushdown_dnf(predicates, sources, scope_ref_count)
-def pushdown_cnf(predicates, scope):
+def pushdown_cnf(predicates, scope, scope_ref_count):
"""
If the predicates are in CNF like form, we can simply replace each block in the parent.
"""
for predicate in predicates:
- for node in nodes_for_predicate(predicate, scope).values():
+ for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
predicate.replace(exp.TRUE)
node.on(predicate, copy=False)
@@ -65,7 +75,7 @@ def pushdown_cnf(predicates, scope):
node.where(replace_aliases(node, predicate), copy=False)
-def pushdown_dnf(predicates, scope):
+def pushdown_dnf(predicates, scope, scope_ref_count):
"""
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
Additionally, we can't remove predicates from their original form.
@@ -91,7 +101,7 @@ def pushdown_dnf(predicates, scope):
# (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
for table in sorted(pushdown_tables):
for predicate in predicates:
- nodes = nodes_for_predicate(predicate, scope)
+ nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
if table not in nodes:
continue
@@ -120,7 +130,7 @@ def pushdown_dnf(predicates, scope):
node.where(replace_aliases(node, predicate), copy=False)
-def nodes_for_predicate(predicate, sources):
+def nodes_for_predicate(predicate, sources, scope_ref_count):
nodes = {}
tables = exp.column_table_names(predicate)
where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
@@ -133,7 +143,7 @@ def nodes_for_predicate(predicate, sources):
if node and where_condition:
node = node.find_ancestor(exp.Join, exp.From)
- # a node can reference a CTE which should be push down
+ # a node can reference a CTE which should be pushed down
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
node = source.expression
@@ -142,7 +152,9 @@ def nodes_for_predicate(predicate, sources):
return {}
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
- if not node.args.get("group"):
+ # we can't push down predicates to select statements if they are referenced in
+ # multiple places.
+ if not node.args.get("group") and scope_ref_count[id(source)] < 2:
nodes[table] = node
return nodes