Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.optimizer.qualify_columns import Resolver
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6from sqlglot.schema import ensure_schema
  7
  8# Sentinel value that means an outer query selecting ALL columns
  9SELECT_ALL = object()
 10
 11# Selection to use if selection list is empty
 12DEFAULT_SELECTION = lambda: alias("1", "_")
 13
 14
 15def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 16    """
 17    Rewrite sqlglot AST to remove unused columns projections.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 22        >>> expression = sqlglot.parse_one(sql)
 23        >>> pushdown_projections(expression).sql()
 24        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 25
 26    Args:
 27        expression (sqlglot.Expression): expression to optimize
 28        remove_unused_selections (bool): remove selects that are unused
 29    Returns:
 30        sqlglot.Expression: optimized expression
 31    """
 32    # Map of Scope to all columns being selected by outer queries.
 33    schema = ensure_schema(schema)
 34    referenced_columns = defaultdict(set)
 35
 36    # We build the scope tree (which is traversed in DFS postorder), then iterate
 37    # over the result in reverse order. This should ensure that the set of selected
 38    # columns for a particular scope are completely build by the time we get to it.
 39    for scope in reversed(traverse_scope(expression)):
 40        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 41
 42        if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
 43            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
 44            # we select from a pivoted source in the parent scope.
 45            parent_selections = {SELECT_ALL}
 46
 47        if isinstance(scope.expression, exp.Union):
 48            left, right = scope.union_scopes
 49            referenced_columns[left] = parent_selections
 50
 51            if any(select.is_star for select in right.expression.selects):
 52                referenced_columns[right] = parent_selections
 53            elif not any(select.is_star for select in left.expression.selects):
 54                referenced_columns[right] = [
 55                    right.expression.selects[i].alias_or_name
 56                    for i, select in enumerate(left.expression.selects)
 57                    if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
 58                ]
 59
 60        if isinstance(scope.expression, exp.Select):
 61            if remove_unused_selections:
 62                _remove_unused_selections(scope, parent_selections, schema)
 63
 64            if scope.expression.is_star:
 65                continue
 66
 67            # Group columns by source name
 68            selects = defaultdict(set)
 69            for col in scope.columns:
 70                table_name = col.table
 71                col_name = col.name
 72                selects[table_name].add(col_name)
 73
 74            # Push the selected columns down to the next scope
 75            for name, (_, source) in scope.selected_sources.items():
 76                if isinstance(source, Scope):
 77                    columns = selects.get(name) or set()
 78                    referenced_columns[source].update(columns)
 79
 80    return expression
 81
 82
 83def _remove_unused_selections(scope, parent_selections, schema):
 84    order = scope.expression.args.get("order")
 85
 86    if order:
 87        # Assume columns without a qualified table are references to output columns
 88        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
 89    else:
 90        order_refs = set()
 91
 92    new_selections = []
 93    removed = False
 94    star = False
 95
 96    for selection in scope.expression.selects:
 97        name = selection.alias_or_name
 98
 99        if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
100            new_selections.append(selection)
101        else:
102            if selection.is_star:
103                star = True
104            removed = True
105
106    if star:
107        resolver = Resolver(scope, schema)
108        names = {s.alias_or_name for s in new_selections}
109
110        for name in sorted(parent_selections):
111            if name not in names:
112                new_selections.append(
113                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
114                )
115
116    # If there are no remaining selections, just select a single constant
117    if not new_selections:
118        new_selections.append(DEFAULT_SELECTION())
119
120    scope.expression.select(*new_selections, append=False, copy=False)
121
122    if removed:
123        scope.clear_cache()
SELECT_ALL = <object object>
def DEFAULT_SELECTION():
13DEFAULT_SELECTION = lambda: alias("1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and handed
            to the column resolver used when a removed `*` selection has to be
            replaced by explicit columns
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression
    """
    schema = ensure_schema(schema)

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
            # we select from a pivoted source in the parent scope.
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.expression.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.expression.selects):
                # Columns of a UNION are matched by position: keep on the right side
                # whatever sits at the positions that are kept on the left side.
                # This has to be a set (not a list) so the entry has the same type as
                # every other value of `referenced_columns`, which are extended with
                # `.update(...)` further down.
                select_all = SELECT_ALL in parent_selections
                right_selects = right.expression.selects
                referenced_columns[right] = {
                    right_selects[i].alias_or_name
                    for i, select in enumerate(left.expression.selects)
                    if select_all or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    columns = selects.get(name) or set()
                    referenced_columns[source].update(columns)

    return expression

Rewrite sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • remove_unused_selections (bool): remove selects that are unused
Returns:

sqlglot.Expression: optimized expression