summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/pushdown_predicates.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/pushdown_predicates.py')
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py28
1 files changed, 22 insertions, 6 deletions
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 583d059..6364f65 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count):
condition = condition.replace(simplify(condition))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
- predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
+ predicates = list(
+ condition.flatten()
+ if isinstance(condition, exp.And if cnf_like else exp.Or)
+ else [condition]
+ )
if cnf_like:
pushdown_cnf(predicates, sources, scope_ref_count)
@@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
- predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
+ predicate_condition = (
+ exp.and_(predicate_condition, condition)
+ if predicate_condition
+ else condition
+ )
if predicate_condition:
conditions[table] = (
- exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
+ exp.or_(conditions[table], predicate_condition)
+ if table in conditions
+ else predicate_condition
)
for name, node in nodes.items():
@@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
# We can't push down window expressions
- has_window_expression = any(select for select in node.selects if select.find(exp.Window))
+ has_window_expression = any(
+ select for select in node.selects if select.find(exp.Window)
+ )
# we can't push down predicates to select statements if they are referenced in
# multiple places.
- if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
+ if (
+ not node.args.get("group")
+ and scope_ref_count[id(source)] < 2
+ and not has_window_expression
+ ):
nodes[table] = node
return nodes
@@ -165,7 +181,7 @@ def replace_aliases(source, predicate):
def _replace_alias(column):
if isinstance(column, exp.Column) and column.name in aliases:
- return aliases[column.name]
+ return aliases[column.name].copy()
return column
return predicate.transform(_replace_alias)