Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.optimizer.qualify_columns import Resolver
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6from sqlglot.schema import ensure_schema
  7
  8# Sentinel value meaning that an outer query selects ALL columns of a scope
  9SELECT_ALL = object()
 10
 11# Factory for the placeholder selection, alias("1", "_"), used if the selection list is empty
 12DEFAULT_SELECTION = lambda: alias("1", "_")
 13
 14
 15def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 16    """
 17    Rewrite sqlglot AST to remove unused columns projections.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 22        >>> expression = sqlglot.parse_one(sql)
 23        >>> pushdown_projections(expression).sql()
 24        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 25
 26    Args:
 27        expression (sqlglot.Expression): expression to optimize
 28        remove_unused_selections (bool): remove selects that are unused
 29    Returns:
 30        sqlglot.Expression: optimized expression
 31    """
 32    # Map of Scope to all columns being selected by outer queries.
 33    schema = ensure_schema(schema)
 34    source_column_alias_count = {}
 35    referenced_columns = defaultdict(set)
 36
 37    # We build the scope tree (which is traversed in DFS postorder), then iterate
 38    # over the result in reverse order. This should ensure that the set of selected
 39    # columns for a particular scope are completely build by the time we get to it.
 40    for scope in reversed(traverse_scope(expression)):
 41        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 42        alias_count = source_column_alias_count.get(scope, 0)
 43
 44        if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
 45            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
 46            # we select from a pivoted source in the parent scope.
 47            parent_selections = {SELECT_ALL}
 48
 49        if isinstance(scope.expression, exp.Union):
 50            left, right = scope.union_scopes
 51            referenced_columns[left] = parent_selections
 52
 53            if any(select.is_star for select in right.expression.selects):
 54                referenced_columns[right] = parent_selections
 55            elif not any(select.is_star for select in left.expression.selects):
 56                referenced_columns[right] = [
 57                    right.expression.selects[i].alias_or_name
 58                    for i, select in enumerate(left.expression.selects)
 59                    if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
 60                ]
 61
 62        if isinstance(scope.expression, exp.Select):
 63            if remove_unused_selections:
 64                _remove_unused_selections(scope, parent_selections, schema, alias_count)
 65
 66            if scope.expression.is_star:
 67                continue
 68
 69            # Group columns by source name
 70            selects = defaultdict(set)
 71            for col in scope.columns:
 72                table_name = col.table
 73                col_name = col.name
 74                selects[table_name].add(col_name)
 75
 76            # Push the selected columns down to the next scope
 77            for name, (node, source) in scope.selected_sources.items():
 78                if isinstance(source, Scope):
 79                    columns = selects.get(name) or set()
 80                    referenced_columns[source].update(columns)
 81
 82                column_aliases = node.alias_column_names
 83                if column_aliases:
 84                    source_column_alias_count[source] = len(column_aliases)
 85
 86    return expression
 87
 88
 89def _remove_unused_selections(scope, parent_selections, schema, alias_count):
 90    order = scope.expression.args.get("order")
 91
 92    if order:
 93        # Assume columns without a qualified table are references to output columns
 94        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
 95    else:
 96        order_refs = set()
 97
 98    new_selections = []
 99    removed = False
100    star = False
101
102    select_all = SELECT_ALL in parent_selections
103
104    for selection in scope.expression.selects:
105        name = selection.alias_or_name
106
107        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
108            new_selections.append(selection)
109            alias_count -= 1
110        else:
111            if selection.is_star:
112                star = True
113            removed = True
114
115    if star:
116        resolver = Resolver(scope, schema)
117        names = {s.alias_or_name for s in new_selections}
118
119        for name in sorted(parent_selections):
120            if name not in names:
121                new_selections.append(
122                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
123                )
124
125    # If there are no remaining selections, just select a single constant
126    if not new_selections:
127        new_selections.append(DEFAULT_SELECTION())
128
129    scope.expression.select(*new_selections, append=False, copy=False)
130
131    if removed:
132        scope.clear_cache()
SELECT_ALL = <object object>
def DEFAULT_SELECTION():
13DEFAULT_SELECTION = lambda: alias("1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    The expression is modified in place and returned.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: schema information, normalized with `ensure_schema`; it is handed to the
            column resolver when a removed `*` projection has to be expanded
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression
    """
    schema = ensure_schema(schema)

    # Map of Scope to the number of column aliases declared on the source that selects it.
    source_column_alias_count = {}

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # `.get` (not indexing) so that a scope nobody has referenced yet is treated as
        # "everything is selected" instead of receiving the defaultdict's empty set.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
            # we select from a pivoted source in the parent scope.
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            left_selects = left.expression.selects
            right_selects = right.expression.selects

            if any(select.is_star for select in right_selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left_selects):
                keep_all = SELECT_ALL in parent_selections
                # Union branches are matched by position, so the right-hand names are looked
                # up through the left-hand projection at the same index. A set keeps this
                # entry consistent with every other value stored in `referenced_columns`,
                # and `zip` avoids an IndexError when the right side has fewer projections.
                referenced_columns[right] = {
                    right_select.alias_or_name
                    for left_select, right_select in zip(left_selects, right_selects)
                    if keep_all or left_select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    referenced_columns[source].update(selects.get(name) or set())

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression

Rewrite sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
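Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • schema: schema information, normalized with ensure_schema and used to resolve a column's table when a removed `*` projection is expanded
  • remove_unused_selections (bool): remove selects that are unused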
Returns:

sqlglot.Expression: optimized expression