diff options
Diffstat (limited to 'sqlglot/optimizer/pushdown_predicates.py')
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 38 |
1 files changed, 25 insertions, 13 deletions
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index a070d70..9c8d71d 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from sqlglot import exp from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import traverse_scope @@ -20,22 +22,30 @@ def pushdown_predicates(expression): Returns: sqlglot.Expression: optimized expression """ - for scope in reversed(traverse_scope(expression)): + scope_ref_count = defaultdict(lambda: 0) + scopes = traverse_scope(expression) + scopes.reverse() + + for scope in scopes: + for _, source in scope.selected_sources.values(): + scope_ref_count[id(source)] += 1 + + for scope in scopes: select = scope.expression where = select.args.get("where") if where: - pushdown(where.this, scope.selected_sources) + pushdown(where.this, scope.selected_sources, scope_ref_count) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself for join in select.args.get("joins") or []: name = join.this.alias_or_name - pushdown(join.args.get("on"), {name: scope.selected_sources[name]}) + pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) return expression -def pushdown(condition, sources): +def pushdown(condition, sources, scope_ref_count): if not condition: return @@ -45,17 +55,17 @@ def pushdown(condition, sources): predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) if cnf_like: - pushdown_cnf(predicates, sources) + pushdown_cnf(predicates, sources, scope_ref_count) else: - pushdown_dnf(predicates, sources) + pushdown_dnf(predicates, sources, scope_ref_count) -def pushdown_cnf(predicates, scope): +def pushdown_cnf(predicates, scope, scope_ref_count): """ If the predicates are in CNF like form, we can simply replace each block in the parent. """ for predicate in predicates: - for node in nodes_for_predicate(predicate, scope).values(): + for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): if isinstance(node, exp.Join): predicate.replace(exp.TRUE) node.on(predicate, copy=False) @@ -65,7 +75,7 @@ def pushdown_cnf(predicates, scope): node.where(replace_aliases(node, predicate), copy=False) -def pushdown_dnf(predicates, scope): +def pushdown_dnf(predicates, scope, scope_ref_count): """ If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form. @@ -91,7 +101,7 @@ def pushdown_dnf(predicates, scope): # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z) for table in sorted(pushdown_tables): for predicate in predicates: - nodes = nodes_for_predicate(predicate, scope) + nodes = nodes_for_predicate(predicate, scope, scope_ref_count) if table not in nodes: continue @@ -120,7 +130,7 @@ def pushdown_dnf(predicates, scope): node.where(replace_aliases(node, predicate), copy=False) -def nodes_for_predicate(predicate, sources): +def nodes_for_predicate(predicate, sources, scope_ref_count): nodes = {} tables = exp.column_table_names(predicate) where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) @@ -133,7 +143,7 @@ def nodes_for_predicate(predicate, sources): if node and where_condition: node = node.find_ancestor(exp.Join, exp.From) - # a node can reference a CTE which should be push down + # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): node = source.expression @@ -142,7 +152,9 @@ def nodes_for_predicate(predicate, sources): return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: - if not node.args.get("group"): + # we can't push down predicates to select statements if they are referenced in + # multiple places. + if not node.args.get("group") and scope_ref_count[id(source)] < 2: nodes[table] = node return nodes |