diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index f7348b5..10ff13a 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -4,7 +4,7 @@ from sqlglot.optimizer.scope import build_scope, find_in_scope from sqlglot.optimizer.simplify import simplify -def pushdown_predicates(expression): +def pushdown_predicates(expression, dialect=None): """ Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS @@ -36,7 +36,7 @@ def pushdown_predicates(expression): if isinstance(parent, exp.Join) and parent.side == "RIGHT": selected_sources = {k: (node, source)} break - pushdown(where.this, selected_sources, scope_ref_count) + pushdown(where.this, selected_sources, scope_ref_count, dialect) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself @@ -44,17 +44,20 @@ def pushdown_predicates(expression): name = join.alias_or_name if name in scope.selected_sources: pushdown( - join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count + join.args.get("on"), + {name: scope.selected_sources[name]}, + scope_ref_count, + dialect, ) return expression -def pushdown(condition, sources, scope_ref_count): +def pushdown(condition, sources, scope_ref_count, dialect): if not condition: return - condition = condition.replace(simplify(condition)) + condition = condition.replace(simplify(condition, dialect=dialect)) cnf_like = normalized(condition) or not normalized(condition, dnf=True) predicates = list( |