diff options
Diffstat (limited to 'sqlglot/optimizer/pushdown_projections.py')
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 76 |
1 files changed, 41 insertions, 35 deletions
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 54c5021..3f360f9 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -1,7 +1,10 @@ from collections import defaultdict from sqlglot import alias, exp +from sqlglot.helper import flatten +from sqlglot.optimizer.qualify_columns import Resolver from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import ensure_schema # Sentinel value that means an outer query selecting ALL columns SELECT_ALL = object() @@ -10,7 +13,7 @@ SELECT_ALL = object() DEFAULT_SELECTION = lambda: alias("1", "_") -def pushdown_projections(expression): +def pushdown_projections(expression, schema=None): """ Rewrite sqlglot AST to remove unused columns projections. @@ -27,9 +30,9 @@ def pushdown_projections(expression): sqlglot.Expression: optimized expression """ # Map of Scope to all columns being selected by outer queries. + schema = ensure_schema(schema) referenced_columns = defaultdict(set) - left_union = None - right_union = None + # We build the scope tree (which is traversed in DFS postorder), then iterate # over the result in reverse order. This should ensure that the set of selected # columns for a particular scope are completely build by the time we get to it. @@ -41,16 +44,20 @@ def pushdown_projections(expression): parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): - left_union, right_union = scope.union_scopes - referenced_columns[left_union] = parent_selections - referenced_columns[right_union] = parent_selections + left, right = scope.union_scopes + referenced_columns[left] = parent_selections + + if any(select.is_star for select in right.selects): + referenced_columns[right] = parent_selections + elif not any(select.is_star for select in left.selects): + referenced_columns[right] = [ + right.selects[i].alias_or_name + for i, select in enumerate(left.selects) + if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections + ] - if isinstance(scope.expression, exp.Select) and scope != right_union: - removed_indexes = _remove_unused_selections(scope, parent_selections) - # The left union is used for column names to select and if we remove columns from the left - # we need to also remove those same columns in the right that were at the same position - if scope is left_union: - _remove_indexed_selections(right_union, removed_indexes) + if isinstance(scope.expression, exp.Select): + _remove_unused_selections(scope, parent_selections, schema) # Group columns by source name selects = defaultdict(set) @@ -68,8 +75,7 @@ def pushdown_projections(expression): return expression -def _remove_unused_selections(scope, parent_selections): - removed_indexes = [] +def _remove_unused_selections(scope, parent_selections, schema): order = scope.expression.args.get("order") if order: @@ -78,33 +84,33 @@ def _remove_unused_selections(scope, parent_selections): else: order_refs = set() - new_selections = [] + new_selections = defaultdict(list) removed = False - for i, selection in enumerate(scope.selects): - if ( - SELECT_ALL in parent_selections - or selection.alias_or_name in parent_selections - or selection.alias_or_name in order_refs - ): - new_selections.append(selection) + star = False + for selection in scope.selects: + name = selection.alias_or_name + + if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs: + new_selections[name].append(selection) else: - removed_indexes.append(i) + if selection.is_star: + star = True removed = True + if star: + resolver = Resolver(scope, schema) + + for name in sorted(parent_selections): + if name not in new_selections: + new_selections[name].append( + alias(exp.column(name, table=resolver.get_table(name)), name) + ) + # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION()) + new_selections[""].append(DEFAULT_SELECTION()) + + scope.expression.select(*flatten(new_selections.values()), append=False, copy=False) - scope.expression.set("expressions", new_selections) if removed: scope.clear_cache() - return removed_indexes - - -def _remove_indexed_selections(scope, indexes_to_remove): - new_selections = [ - selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove - ] - if not new_selections: - new_selections.append(DEFAULT_SELECTION()) - scope.expression.set("expressions", new_selections) |