summary refs log tree commit diff stats
path: root/sqlglot/optimizer/pushdown_predicates.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/pushdown_predicates.py')
-rw-r--r-- sqlglot/optimizer/pushdown_predicates.py 50
1 files changed, 23 insertions, 27 deletions
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 10ff13a..12c3b89 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -30,13 +30,18 @@ def pushdown_predicates(expression, dialect=None):
where = select.args.get("where")
if where:
selected_sources = scope.selected_sources
+ join_index = {
+ join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
+ }
+
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
- pushdown(where.this, selected_sources, scope_ref_count, dialect)
+
+ pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@@ -53,7 +58,7 @@ def pushdown_predicates(expression, dialect=None):
return expression
-def pushdown(condition, sources, scope_ref_count, dialect):
+def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
if not condition:
return
@@ -67,21 +72,28 @@ def pushdown(condition, sources, scope_ref_count, dialect):
)
if cnf_like:
- pushdown_cnf(predicates, sources, scope_ref_count)
+ pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
else:
pushdown_dnf(predicates, sources, scope_ref_count)
-def pushdown_cnf(predicates, scope, scope_ref_count):
+def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
"""
If the predicates are in CNF like form, we can simply replace each block in the parent.
"""
+ join_index = join_index or {}
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
- predicate.replace(exp.true())
- node.on(predicate, copy=False)
- break
+ name = node.alias_or_name
+ predicate_tables = exp.column_table_names(predicate, name)
+
+ # Don't push the predicate if it references tables that appear in later joins
+ this_index = join_index[name]
+ if all(join_index.get(table, -1) < this_index for table in predicate_tables):
+ predicate.replace(exp.true())
+ node.on(predicate, copy=False)
+ break
if isinstance(node, exp.Select):
predicate.replace(exp.true())
inner_predicate = replace_aliases(node, predicate)
@@ -112,9 +124,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
conditions = {}
- # for every pushdown table, find all related conditions in all predicates
- # combine them with ORS
- # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
+ # pushdown all predicates to their respective nodes
for table in sorted(pushdown_tables):
for predicate in predicates:
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
@@ -122,23 +132,9 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
if table not in nodes:
continue
- predicate_condition = None
-
- for column in predicate.find_all(exp.Column):
- if column.table == table:
- condition = column.find_ancestor(exp.Condition)
- predicate_condition = (
- exp.and_(predicate_condition, condition)
- if predicate_condition
- else condition
- )
-
- if predicate_condition:
- conditions[table] = (
- exp.or_(conditions[table], predicate_condition)
- if table in conditions
- else predicate_condition
- )
+ conditions[table] = (
+ exp.or_(conditions[table], predicate) if table in conditions else predicate
+ )
for name, node in nodes.items():
if name not in conditions: