diff options
Diffstat (limited to 'sqlglot/optimizer/pushdown_projections.py')
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index c81fd00..b51601f 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -31,6 +31,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) """ # Map of Scope to all columns being selected by outer queries. schema = ensure_schema(schema) + source_column_alias_count = {} referenced_columns = defaultdict(set) # We build the scope tree (which is traversed in DFS postorder), then iterate @@ -38,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) # columns for a particular scope are completely build by the time we get to it. for scope in reversed(traverse_scope(expression)): parent_selections = referenced_columns.get(scope, {SELECT_ALL}) + alias_count = source_column_alias_count.get(scope, 0) - if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots: + if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots): # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if # we select from a pivoted source in the parent scope. parent_selections = {SELECT_ALL} @@ -59,7 +61,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) if isinstance(scope.expression, exp.Select): if remove_unused_selections: - _remove_unused_selections(scope, parent_selections, schema) + _remove_unused_selections(scope, parent_selections, schema, alias_count) if scope.expression.is_star: continue @@ -72,15 +74,19 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) selects[table_name].add(col_name) # Push the selected columns down to the next scope - for name, (_, source) in scope.selected_sources.items(): + for name, (node, source) in scope.selected_sources.items(): if isinstance(source, Scope): columns = selects.get(name) or set() referenced_columns[source].update(columns) + column_aliases = node.alias_column_names + if column_aliases: + source_column_alias_count[source] = len(column_aliases) + return expression -def _remove_unused_selections(scope, parent_selections, schema): +def _remove_unused_selections(scope, parent_selections, schema, alias_count): order = scope.expression.args.get("order") if order: @@ -93,11 +99,14 @@ def _remove_unused_selections(scope, parent_selections, schema): removed = False star = False + select_all = SELECT_ALL in parent_selections + for selection in scope.expression.selects: name = selection.alias_or_name - if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs: + if select_all or name in parent_selections or name in order_refs or alias_count > 0: new_selections.append(selection) + alias_count -= 1 else: if selection.is_star: star = True |