summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/pushdown_projections.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/pushdown_projections.py')
-rw-r--r--sqlglot/optimizer/pushdown_projections.py35
1 files changed, 27 insertions, 8 deletions
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 5584830..5820851 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -6,6 +6,9 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
+# Selection to use if the selection list is empty
+DEFAULT_SELECTION = alias("1", "_")
+
def pushdown_projections(expression):
"""
@@ -25,7 +28,8 @@ def pushdown_projections(expression):
"""
# Map of Scope to all columns being selected by outer queries.
referenced_columns = defaultdict(set)
-
+ left_union = None
+ right_union = None
# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This should ensure that the set of selected
# columns for a particular scope is completely built by the time we get to it.
@@ -37,12 +41,16 @@ def pushdown_projections(expression):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
- left, right = scope.union_scopes
- referenced_columns[left] = parent_selections
- referenced_columns[right] = parent_selections
+ left_union, right_union = scope.union_scopes
+ referenced_columns[left_union] = parent_selections
+ referenced_columns[right_union] = parent_selections
- if isinstance(scope.expression, exp.Select):
- _remove_unused_selections(scope, parent_selections)
+ if isinstance(scope.expression, exp.Select) and scope != right_union:
+ removed_indexes = _remove_unused_selections(scope, parent_selections)
+ # The left side of the union supplies the column names to select, so when we remove columns
+ # from the left we must also remove the columns at the same positions from the right.
+ if scope is left_union:
+ _remove_indexed_selections(right_union, removed_indexes)
# Group columns by source name
selects = defaultdict(set)
@@ -61,6 +69,7 @@ def pushdown_projections(expression):
def _remove_unused_selections(scope, parent_selections):
+ removed_indexes = []
order = scope.expression.args.get("order")
if order:
@@ -70,16 +79,26 @@ def _remove_unused_selections(scope, parent_selections):
order_refs = set()
new_selections = []
- for selection in scope.selects:
+ for i, selection in enumerate(scope.selects):
if (
SELECT_ALL in parent_selections
or selection.alias_or_name in parent_selections
or selection.alias_or_name in order_refs
):
new_selections.append(selection)
+ else:
+ removed_indexes.append(i)
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(alias("1", "_"))
+ new_selections.append(DEFAULT_SELECTION)
+
+ scope.expression.set("expressions", new_selections)
+ return removed_indexes
+
+def _remove_indexed_selections(scope, indexes_to_remove):
+ new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
+ if not new_selections:
+ new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)