diff options
Diffstat (limited to 'sqlglot/optimizer/pushdown_predicates.py')
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 583d059..6364f65 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count): condition = condition.replace(simplify(condition)) cnf_like = normalized(condition) or not normalized(condition, dnf=True) - predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) + predicates = list( + condition.flatten() + if isinstance(condition, exp.And if cnf_like else exp.Or) + else [condition] + ) if cnf_like: pushdown_cnf(predicates, sources, scope_ref_count) @@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count): for column in predicate.find_all(exp.Column): if column.table == table: condition = column.find_ancestor(exp.Condition) - predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition + predicate_condition = ( + exp.and_(predicate_condition, condition) + if predicate_condition + else condition + ) if predicate_condition: conditions[table] = ( - exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition + exp.or_(conditions[table], predicate_condition) + if table in conditions + else predicate_condition ) for name, node in nodes.items(): @@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: # We can't push down window expressions - has_window_expression = any(select for select in node.selects if select.find(exp.Window)) + has_window_expression = any( + select for select in node.selects if select.find(exp.Window) + ) # we can't push down predicates to select statements if they are referenced in # multiple places. - if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression: + if ( + not node.args.get("group") + and scope_ref_count[id(source)] < 2 + and not has_window_expression + ): nodes[table] = node return nodes @@ -165,7 +181,7 @@ def replace_aliases(source, predicate): def _replace_alias(column): if isinstance(column, exp.Column) and column.name in aliases: - return aliases[column.name] + return aliases[column.name].copy() return column return predicate.transform(_replace_alias) |