diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-15 13:52:53 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-15 13:52:53 +0000 |
commit | 97d3673ec2d668050912aa6aea1816885ca6c5ab (patch) | |
tree | f391e30e039a3d22368e9696e171f759e104c765 /sqlglot/optimizer/pushdown_predicates.py | |
parent | Adding upstream version 6.3.1. (diff) | |
download | sqlglot-97d3673ec2d668050912aa6aea1816885ca6c5ab.tar.xz sqlglot-97d3673ec2d668050912aa6aea1816885ca6c5ab.zip |
Adding upstream version 7.1.3.upstream/7.1.3
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/pushdown_predicates.py')
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 19 |
1 files changed, 7 insertions, 12 deletions
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 9c8d71d..583d059 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,8 +1,6 @@ -from collections import defaultdict - from sqlglot import exp from sqlglot.optimizer.normalize import normalized -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import build_scope from sqlglot.optimizer.simplify import simplify @@ -22,15 +20,10 @@ def pushdown_predicates(expression): Returns: sqlglot.Expression: optimized expression """ - scope_ref_count = defaultdict(lambda: 0) - scopes = traverse_scope(expression) - scopes.reverse() - - for scope in scopes: - for _, source in scope.selected_sources.values(): - scope_ref_count[id(source)] += 1 + root = build_scope(expression) + scope_ref_count = root.ref_count() - for scope in scopes: + for scope in reversed(list(root.traverse())): select = scope.expression where = select.args.get("where") if where: @@ -152,9 +145,11 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: + # We can't push down window expressions + has_window_expression = any(select for select in node.selects if select.find(exp.Window)) # we can't push down predicates to select statements if they are referenced in # multiple places. - if not node.args.get("group") and scope_ref_count[id(source)] < 2: + if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression: nodes[table] = node return nodes |