sqlglot.optimizer.pushdown_projections
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.helper import flatten
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema

# Sentinel value that means an outer query is selecting ALL columns
SELECT_ALL = object()


def DEFAULT_SELECTION():
    """Return a fresh placeholder projection, built as ``alias("1", "_")``.

    Used as the selection when a selection list would otherwise be empty.
    """
    return alias("1", "_")


def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and used to
            resolve the table of a column when a removed star selection is
            replaced by explicit columns
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the input, modified in place)
    """
    schema = ensure_schema(schema)

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # `.get` is used on purpose: a scope nobody registered columns for is
        # treated as "everything is selected" without creating a map entry.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.selects):
                # Output names of a union come from the left side, so map each
                # kept left projection to the right projection at the same
                # position. Stored as a set, like every other value in the map.
                referenced_columns[right] = {
                    right.selects[i].alias_or_name
                    for i, select in enumerate(left.selects)
                    if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema)

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    columns = selects.get(name) or set()
                    referenced_columns[source].update(columns)

    return expression


def _remove_unused_selections(scope, parent_selections, schema):
    """
    Replace the projections of `scope` with only those that are still needed.

    A projection is kept when `parent_selections` contains `SELECT_ALL`, when its
    alias or name is in `parent_selections`, or when its name is referenced by an
    unqualified column in the scope's ORDER BY. The scope's expression is
    modified in place.

    Args:
        scope: scope whose selections are pruned
        parent_selections: names (or `SELECT_ALL`) requested by outer queries
        schema: schema handed to `Resolver` to find the table of a column
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    new_selections = defaultdict(list)
    removed = False
    star = False
    for selection in scope.selects:
        name = selection.alias_or_name

        if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
            new_selections[name].append(selection)
        else:
            if selection.is_star:
                star = True
            removed = True

    if star:
        # A star was dropped, so the columns the parent asked for may have been
        # provided by it: add them back explicitly, qualified with their table.
        resolver = Resolver(scope, schema)

        for name in sorted(parent_selections):
            if name not in new_selections:
                new_selections[name].append(
                    alias(exp.column(name, table=resolver.get_table(name)), name)
                )

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections[""].append(DEFAULT_SELECTION())

    scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)

    if removed:
        # The selection list changed, so cached scope data must be dropped
        scope.clear_cache()
def
DEFAULT_SELECTION():
def DEFAULT_SELECTION():
    """Return a fresh placeholder projection, built as ``alias("1", "_")``.

    Used as the selection when a selection list would otherwise be empty.
    """
    return alias("1", "_")
def
pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and passed
            on to `_remove_unused_selections`
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    schema = ensure_schema(schema)

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # `.get` is used on purpose: a scope nobody registered columns for is
        # treated as "everything is selected" without creating a map entry.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.selects):
                # Output names of a union come from the left side, so map each
                # kept left projection to the right projection at the same
                # position. Stored as a set, like every other value in the map.
                referenced_columns[right] = {
                    right.selects[i].alias_or_name
                    for i, select in enumerate(left.selects)
                    if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema)

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    columns = selects.get(name) or set()
                    referenced_columns[source].update(columns)

    return expression
Rewrite sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
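- expression (sqlglot.Expression): expression to optimize
- schema: optional schema, used to resolve the table of a column when a removed star selection is replaced by explicit columns
- remove_unused_selections (bool): remove selects that are unused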
Returns:
sqlglot.Expression: optimized expression