sqlglot.optimizer.pushdown_projections
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.optimizer.scope import Scope, traverse_scope

# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()


def DEFAULT_SELECTION():
    """Return the selection to use when a selection list would otherwise be empty."""
    return alias("1", "_")


def pushdown_projections(expression):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)
    # Scopes that sit on the right-hand side of a union (at any depth). Their
    # projections are never pruned by name; they only mirror, by position, what
    # the leftmost SELECT of the enclosing union drops.
    positional = set()
    # Map of a union's left branch to every right-hand scope that has to drop
    # the same positions as that branch's leftmost SELECT. Keeping this per
    # scope (instead of in two loop-wide variables) is what makes nested unions
    # such as `(a UNION ALL b) UNION ALL c` stay column-compatible.
    mirrors = {}

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes

            if scope in positional:
                # The whole union is itself a right-hand side: both of its
                # branches just follow the positions dictated from above.
                positional.add(left)
                positional.add(right)
            else:
                if _contains_distinct(right):
                    # Dropping a column from a DISTINCT query changes which rows
                    # count as duplicates, so nothing may be pruned by position.
                    parent_selections = {SELECT_ALL}
                positional.add(right)
                # The left branch inherits whatever this union already had to
                # keep in sync, plus this union's own right-hand side.
                mirrors[left] = mirrors.pop(scope, []) + [right]

            referenced_columns[left] = parent_selections
            referenced_columns[right] = parent_selections

        if isinstance(scope.expression, exp.Select) and scope not in positional:
            removed_indexes = _remove_unused_selections(scope, parent_selections)
            # The left union is used for column names to select and if we remove columns from the
            # left we need to also remove those same columns in every right-hand side that has to
            # line up with it.
            if removed_indexes:
                for mirrored in mirrors.get(scope, ()):
                    _mirror_removed_selections(mirrored, removed_indexes)

        # Group columns by source name
        selects = defaultdict(set)
        for col in scope.columns:
            selects[col.table].add(col.name)

        # Push the selected columns down to the next scope. Indexing
        # `referenced_columns` creates the entry even when no column of the
        # source is referenced, so the source is not treated as SELECT_ALL.
        for name, (_, source) in scope.selected_sources.items():
            if isinstance(source, Scope):
                referenced_columns[source].update(selects.get(name, ()))

    return expression


def _contains_distinct(scope):
    """Return True if `scope`, or any union branch nested inside it, is DISTINCT."""
    if scope.expression.args.get("distinct"):
        return True
    if isinstance(scope.expression, exp.Union):
        return any(_contains_distinct(branch) for branch in scope.union_scopes)
    return False


def _mirror_removed_selections(scope, removed_indexes):
    """Drop the projections at `removed_indexes` from every SELECT under `scope`."""
    if isinstance(scope.expression, exp.Union):
        for branch in scope.union_scopes:
            _mirror_removed_selections(branch, removed_indexes)
    else:
        _remove_indexed_selections(scope, removed_indexes)


def _remove_unused_selections(scope, parent_selections):
    """
    Drop the projections of `scope` that nothing refers to.

    A projection is kept when the parent selects everything, when its alias or
    name is in `parent_selections`, or when the scope's own ORDER BY refers to it.

    Args:
        scope: scope whose expression is rewritten in place
        parent_selections: names selected by outer queries (may hold SELECT_ALL)
    Returns:
        list of int: positions of the projections that were removed
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    keep_all = SELECT_ALL in parent_selections
    new_selections = []
    removed_indexes = []
    for i, selection in enumerate(scope.selects):
        name = selection.alias_or_name
        if keep_all or name in parent_selections or name in order_refs:
            new_selections.append(selection)
        else:
            removed_indexes.append(i)

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION())

    scope.expression.set("expressions", new_selections)
    if removed_indexes:
        scope.clear_cache()
    return removed_indexes


def _remove_indexed_selections(scope, indexes_to_remove):
    """
    Drop the projections of `scope` found at the given positions.

    Args:
        scope: scope whose expression is rewritten in place
        indexes_to_remove: iterable of positions in the selection list to drop
    """
    indexes = set(indexes_to_remove)
    selections = scope.selects
    new_selections = [
        selection for i, selection in enumerate(selections) if i not in indexes
    ]
    removed = len(new_selections) != len(selections)

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION())

    scope.expression.set("expressions", new_selections)
    if removed:
        # Same as _remove_unused_selections: cached lookups may still refer to
        # the projections that were just dropped.
        scope.clear_cache()
def DEFAULT_SELECTION():
    """Return the selection to use when a selection list would otherwise be empty."""
    return alias("1", "_")
def pushdown_projections(expression):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)
    # Scopes that sit on the right-hand side of a union (at any depth). Their
    # projections are never pruned by name; they only mirror, by position, what
    # the leftmost SELECT of the enclosing union drops.
    positional = set()
    # Map of a union's left branch to every right-hand scope that has to drop
    # the same positions as that branch's leftmost SELECT. Keeping this per
    # scope (instead of in two loop-wide variables) is what makes nested unions
    # such as `(a UNION ALL b) UNION ALL c` stay column-compatible.
    mirrors = {}

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes

            if scope in positional:
                # The whole union is itself a right-hand side: both of its
                # branches just follow the positions dictated from above.
                positional.add(left)
                positional.add(right)
            else:
                if _contains_distinct(right):
                    # Dropping a column from a DISTINCT query changes which rows
                    # count as duplicates, so nothing may be pruned by position.
                    parent_selections = {SELECT_ALL}
                positional.add(right)
                # The left branch inherits whatever this union already had to
                # keep in sync, plus this union's own right-hand side.
                mirrors[left] = mirrors.pop(scope, []) + [right]

            referenced_columns[left] = parent_selections
            referenced_columns[right] = parent_selections

        if isinstance(scope.expression, exp.Select) and scope not in positional:
            removed_indexes = _remove_unused_selections(scope, parent_selections)
            # The left union is used for column names to select and if we remove columns from the
            # left we need to also remove those same columns in every right-hand side that has to
            # line up with it.
            if removed_indexes:
                for mirrored in mirrors.get(scope, ()):
                    _mirror_removed_selections(mirrored, removed_indexes)

        # Group columns by source name
        selects = defaultdict(set)
        for col in scope.columns:
            selects[col.table].add(col.name)

        # Push the selected columns down to the next scope. Indexing
        # `referenced_columns` creates the entry even when no column of the
        # source is referenced, so the source is not treated as SELECT_ALL.
        for name, (_, source) in scope.selected_sources.items():
            if isinstance(source, Scope):
                referenced_columns[source].update(selects.get(name, ()))

    return expression


def _contains_distinct(scope):
    """Return True if `scope`, or any union branch nested inside it, is DISTINCT."""
    if scope.expression.args.get("distinct"):
        return True
    if isinstance(scope.expression, exp.Union):
        return any(_contains_distinct(branch) for branch in scope.union_scopes)
    return False


def _mirror_removed_selections(scope, removed_indexes):
    """Drop the projections at `removed_indexes` from every SELECT under `scope`."""
    if isinstance(scope.expression, exp.Union):
        for branch in scope.union_scopes:
            _mirror_removed_selections(branch, removed_indexes)
    else:
        _remove_indexed_selections(scope, removed_indexes)
Rewrite the sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression