summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/pushdown_projections.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/pushdown_projections.py')
-rw-r--r--sqlglot/optimizer/pushdown_projections.py76
1 files changed, 41 insertions, 35 deletions
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 54c5021..3f360f9 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -1,7 +1,10 @@
from collections import defaultdict
from sqlglot import alias, exp
+from sqlglot.helper import flatten
+from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import ensure_schema
# Sentinel value that means an outer query is selecting ALL columns
SELECT_ALL = object()
@@ -10,7 +13,7 @@ SELECT_ALL = object()
DEFAULT_SELECTION = lambda: alias("1", "_")
-def pushdown_projections(expression):
+def pushdown_projections(expression, schema=None):
"""
Rewrite sqlglot AST to remove unused column projections.
@@ -27,9 +30,9 @@ def pushdown_projections(expression):
sqlglot.Expression: optimized expression
"""
# Map of Scope to all columns being selected by outer queries.
+ schema = ensure_schema(schema)
referenced_columns = defaultdict(set)
- left_union = None
- right_union = None
+
# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This should ensure that the set of selected
# columns for a particular scope are completely built by the time we get to it.
@@ -41,16 +44,20 @@ def pushdown_projections(expression):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
- left_union, right_union = scope.union_scopes
- referenced_columns[left_union] = parent_selections
- referenced_columns[right_union] = parent_selections
+ left, right = scope.union_scopes
+ referenced_columns[left] = parent_selections
+
+ if any(select.is_star for select in right.selects):
+ referenced_columns[right] = parent_selections
+ elif not any(select.is_star for select in left.selects):
+ referenced_columns[right] = [
+ right.selects[i].alias_or_name
+ for i, select in enumerate(left.selects)
+ if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
+ ]
- if isinstance(scope.expression, exp.Select) and scope != right_union:
- removed_indexes = _remove_unused_selections(scope, parent_selections)
- # The left union is used for column names to select and if we remove columns from the left
- # we need to also remove those same columns in the right that were at the same position
- if scope is left_union:
- _remove_indexed_selections(right_union, removed_indexes)
+ if isinstance(scope.expression, exp.Select):
+ _remove_unused_selections(scope, parent_selections, schema)
# Group columns by source name
selects = defaultdict(set)
@@ -68,8 +75,7 @@ def pushdown_projections(expression):
return expression
-def _remove_unused_selections(scope, parent_selections):
- removed_indexes = []
+def _remove_unused_selections(scope, parent_selections, schema):
order = scope.expression.args.get("order")
if order:
@@ -78,33 +84,33 @@ def _remove_unused_selections(scope, parent_selections):
else:
order_refs = set()
- new_selections = []
+ new_selections = defaultdict(list)
removed = False
- for i, selection in enumerate(scope.selects):
- if (
- SELECT_ALL in parent_selections
- or selection.alias_or_name in parent_selections
- or selection.alias_or_name in order_refs
- ):
- new_selections.append(selection)
+ star = False
+ for selection in scope.selects:
+ name = selection.alias_or_name
+
+ if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
+ new_selections[name].append(selection)
else:
- removed_indexes.append(i)
+ if selection.is_star:
+ star = True
removed = True
+ if star:
+ resolver = Resolver(scope, schema)
+
+ for name in sorted(parent_selections):
+ if name not in new_selections:
+ new_selections[name].append(
+ alias(exp.column(name, table=resolver.get_table(name)), name)
+ )
+
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(DEFAULT_SELECTION())
+ new_selections[""].append(DEFAULT_SELECTION())
+
+ scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
- scope.expression.set("expressions", new_selections)
if removed:
scope.clear_cache()
- return removed_indexes
-
-
-def _remove_indexed_selections(scope, indexes_to_remove):
- new_selections = [
- selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
- ]
- if not new_selections:
- new_selections.append(DEFAULT_SELECTION())
- scope.expression.set("expressions", new_selections)