Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.optimizer.qualify_columns import Resolver
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6from sqlglot.schema import ensure_schema
  7
  8# Sentinel value that means an outer query selecting ALL columns
  9SELECT_ALL = object()
 10
 11# Selection to use if selection list is empty
 12DEFAULT_SELECTION = lambda: alias("1", "_")
 13
 14
 15def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 16    """
 17    Rewrite sqlglot AST to remove unused columns projections.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 22        >>> expression = sqlglot.parse_one(sql)
 23        >>> pushdown_projections(expression).sql()
 24        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 25
 26    Args:
 27        expression (sqlglot.Expression): expression to optimize
 28        remove_unused_selections (bool): remove selects that are unused
 29    Returns:
 30        sqlglot.Expression: optimized expression
 31    """
 32    # Map of Scope to all columns being selected by outer queries.
 33    schema = ensure_schema(schema)
 34    referenced_columns = defaultdict(set)
 35
 36    # We build the scope tree (which is traversed in DFS postorder), then iterate
 37    # over the result in reverse order. This should ensure that the set of selected
 38    # columns for a particular scope are completely build by the time we get to it.
 39    for scope in reversed(traverse_scope(expression)):
 40        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 41
 42        if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
 43            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
 44            # we select from a pivoted source in the parent scope.
 45            parent_selections = {SELECT_ALL}
 46
 47        if isinstance(scope.expression, exp.Union):
 48            left, right = scope.union_scopes
 49            referenced_columns[left] = parent_selections
 50
 51            if any(select.is_star for select in right.selects):
 52                referenced_columns[right] = parent_selections
 53            elif not any(select.is_star for select in left.selects):
 54                referenced_columns[right] = [
 55                    right.selects[i].alias_or_name
 56                    for i, select in enumerate(left.selects)
 57                    if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
 58                ]
 59
 60        if isinstance(scope.expression, exp.Select):
 61            if remove_unused_selections:
 62                _remove_unused_selections(scope, parent_selections, schema)
 63
 64            # Group columns by source name
 65            selects = defaultdict(set)
 66            for col in scope.columns:
 67                table_name = col.table
 68                col_name = col.name
 69                selects[table_name].add(col_name)
 70
 71            # Push the selected columns down to the next scope
 72            for name, (_, source) in scope.selected_sources.items():
 73                if isinstance(source, Scope):
 74                    columns = selects.get(name) or set()
 75                    referenced_columns[source].update(columns)
 76
 77    return expression
 78
 79
 80def _remove_unused_selections(scope, parent_selections, schema):
 81    order = scope.expression.args.get("order")
 82
 83    if order:
 84        # Assume columns without a qualified table are references to output columns
 85        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
 86    else:
 87        order_refs = set()
 88
 89    new_selections = []
 90    removed = False
 91    star = False
 92
 93    for selection in scope.selects:
 94        name = selection.alias_or_name
 95
 96        if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
 97            new_selections.append(selection)
 98        else:
 99            if selection.is_star:
100                star = True
101            removed = True
102
103    if star:
104        resolver = Resolver(scope, schema)
105        names = {s.alias_or_name for s in new_selections}
106
107        for name in sorted(parent_selections):
108            if name not in names:
109                new_selections.append(
110                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
111                )
112
113    # If there are no remaining selections, just select a single constant
114    if not new_selections:
115        new_selections.append(DEFAULT_SELECTION())
116
117    scope.expression.select(*new_selections, append=False, copy=False)
118
119    if removed:
120        scope.clear_cache()
def DEFAULT_SELECTION():
    """Build the placeholder projection to use if a selection list is empty.

    A fresh expression is built by `alias("1", "_")` on every call.
    """
    return alias("1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: schema description; it is normalized with `ensure_schema` and handed
            to `_remove_unused_selections`, which uses it to resolve the source table
            of columns when a star projection has to be expanded
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    schema = ensure_schema(schema)

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # A scope nobody registered columns for is treated as fully selected.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
            # we select from a pivoted source in the parent scope.
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.selects):
                # The branches are paired by position: a right-hand projection is kept
                # when the left-hand projection at the same index is referenced.
                # Stored as a set, like every other value in `referenced_columns`.
                select_all = SELECT_ALL in parent_selections
                referenced_columns[right] = {
                    right.selects[i].alias_or_name
                    for i, select in enumerate(left.selects)
                    if select_all or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema)

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    # Indexing (rather than `.get`) registers `source` even when no
                    # column references it, so it is later looked up as an empty set
                    # instead of falling back to the SELECT_ALL default.
                    referenced_columns[source].update(selects.get(name, ()))

    return expression

Rewrite sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • remove_unused_selections (bool): remove selects that are unused
Returns:

sqlglot.Expression: optimized expression