diff options
Diffstat (limited to 'sqlglot/optimizer/pushdown_predicates.py')
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 50 |
1 files changed, 23 insertions, 27 deletions
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 10ff13a..12c3b89 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -30,13 +30,18 @@ def pushdown_predicates(expression, dialect=None): where = select.args.get("where") if where: selected_sources = scope.selected_sources + join_index = { + join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or []) + } + # a right join can only push down to itself and not the source FROM table for k, (node, source) in selected_sources.items(): parent = node.find_ancestor(exp.Join, exp.From) if isinstance(parent, exp.Join) and parent.side == "RIGHT": selected_sources = {k: (node, source)} break - pushdown(where.this, selected_sources, scope_ref_count, dialect) + + pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself @@ -53,7 +58,7 @@ def pushdown_predicates(expression, dialect=None): return expression -def pushdown(condition, sources, scope_ref_count, dialect): +def pushdown(condition, sources, scope_ref_count, dialect, join_index=None): if not condition: return @@ -67,21 +72,28 @@ def pushdown(condition, sources, scope_ref_count, dialect): ) if cnf_like: - pushdown_cnf(predicates, sources, scope_ref_count) + pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index) else: pushdown_dnf(predicates, sources, scope_ref_count) -def pushdown_cnf(predicates, scope, scope_ref_count): +def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None): """ If the predicates are in CNF like form, we can simply replace each block in the parent. """ + join_index = join_index or {} for predicate in predicates: for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): if isinstance(node, exp.Join): - predicate.replace(exp.true()) - node.on(predicate, copy=False) - break + name = node.alias_or_name + predicate_tables = exp.column_table_names(predicate, name) + + # Don't push the predicate if it references tables that appear in later joins + this_index = join_index[name] + if all(join_index.get(table, -1) < this_index for table in predicate_tables): + predicate.replace(exp.true()) + node.on(predicate, copy=False) + break if isinstance(node, exp.Select): predicate.replace(exp.true()) inner_predicate = replace_aliases(node, predicate) @@ -112,9 +124,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count): conditions = {} - # for every pushdown table, find all related conditions in all predicates - # combine them with ORS - # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z) + # pushdown all predicates to their respective nodes for table in sorted(pushdown_tables): for predicate in predicates: nodes = nodes_for_predicate(predicate, scope, scope_ref_count) @@ -122,23 +132,9 @@ def pushdown_dnf(predicates, scope, scope_ref_count): if table not in nodes: continue - predicate_condition = None - - for column in predicate.find_all(exp.Column): - if column.table == table: - condition = column.find_ancestor(exp.Condition) - predicate_condition = ( - exp.and_(predicate_condition, condition) - if predicate_condition - else condition - ) - - if predicate_condition: - conditions[table] = ( - exp.or_(conditions[table], predicate_condition) - if table in conditions - else predicate_condition - ) + conditions[table] = ( + exp.or_(conditions[table], predicate) if table in conditions else predicate + ) for name, node in nodes.items(): if name not in conditions: |