sqlglot.optimizer.pushdown_projections
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema

# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()

# Selection to use if selection list is empty
DEFAULT_SELECTION = lambda: alias("1", "_")


def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and handed to
            the column resolver when a removed star projection has to be expanded
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    schema = ensure_schema(schema)

    # Map of Scope to the number of column aliases its parent attaches to it.
    source_column_alias_count = {}

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
            # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT. The same
            # holds if we select from a pivoted source in the parent scope.
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.expression.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.expression.selects):
                # Union branches are matched by position, so translate the names wanted
                # from the left branch into the names at the same index on the right.
                # A set (not a list) keeps this entry consistent with every other value
                # stored in `referenced_columns`.
                select_all = SELECT_ALL in parent_selections
                referenced_columns[right] = {
                    right.expression.selects[i].alias_or_name
                    for i, select in enumerate(left.expression.selects)
                    if select_all or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    columns = selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression


def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """
    Replace the projections of `scope.expression` with only those that are still needed.

    A projection is kept when the parent selects everything, when its output name is
    in `parent_selections`, when an unqualified ORDER BY column refers to it, or while
    `alias_count` is still positive.

    Args:
        scope: scope whose SELECT expression is rewritten in place
        parent_selections: output names referenced by outer queries; may contain
            the SELECT_ALL sentinel
        schema: schema passed to the Resolver when a removed star has to be expanded
        alias_count (int): number of column aliases the parent attaches to this source
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    new_selections = []
    removed = False
    star = False

    select_all = SELECT_ALL in parent_selections

    for selection in scope.expression.selects:
        name = selection.alias_or_name

        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
            new_selections.append(selection)
            # Every kept projection uses up one of the parent's column aliases
            alias_count -= 1
        else:
            if selection.is_star:
                star = True
            removed = True

    if star:
        # A star was dropped: explicitly select the columns the parent still needs
        resolver = Resolver(scope, schema)
        names = {s.alias_or_name for s in new_selections}

        for name in sorted(parent_selections):
            if name not in names:
                new_selections.append(
                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
                )

    # If there are no remaining selections, just select a single constant
    # NOTE(review): if the removed projections were aggregates, a bare constant may
    # change the number of rows the query returns - confirm this case is acceptable.
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION())

    scope.expression.select(*new_selections, append=False, copy=False)

    if removed:
        scope.clear_cache()
SELECT_ALL = <object object>
def DEFAULT_SELECTION():
# Selection to use if selection list is empty
DEFAULT_SELECTION = lambda: alias("1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and passed on
            to `_remove_unused_selections`
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    schema = ensure_schema(schema)

    # Map of Scope to the number of column aliases its parent attaches to it.
    source_column_alias_count = {}

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
            # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT. The same
            # holds if we select from a pivoted source in the parent scope.
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.expression.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.expression.selects):
                # Union branches are matched by position, so translate the names wanted
                # from the left branch into the names at the same index on the right.
                # A set (not a list) keeps this entry consistent with every other value
                # stored in `referenced_columns`.
                select_all = SELECT_ALL in parent_selections
                referenced_columns[right] = {
                    right.expression.selects[i].alias_or_name
                    for i, select in enumerate(left.expression.selects)
                    if select_all or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    columns = selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression
Rewrite sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
- expression (sqlglot.Expression): expression to optimize
- remove_unused_selections (bool): remove selects that are unused
Returns:
sqlglot.Expression: optimized expression