Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.optimizer.scope import Scope, traverse_scope
  5
  6# Sentinel value that means an outer query selecting ALL columns
  7SELECT_ALL = object()
  8
  9# Selection to use if selection list is empty
 10DEFAULT_SELECTION = lambda: alias("1", "_")
 11
 12
 13def pushdown_projections(expression):
 14    """
 15    Rewrite sqlglot AST to remove unused columns projections.
 16
 17    Example:
 18        >>> import sqlglot
 19        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 20        >>> expression = sqlglot.parse_one(sql)
 21        >>> pushdown_projections(expression).sql()
 22        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 23
 24    Args:
 25        expression (sqlglot.Expression): expression to optimize
 26    Returns:
 27        sqlglot.Expression: optimized expression
 28    """
 29    # Map of Scope to all columns being selected by outer queries.
 30    referenced_columns = defaultdict(set)
 31    left_union = None
 32    right_union = None
 33    # We build the scope tree (which is traversed in DFS postorder), then iterate
 34    # over the result in reverse order. This should ensure that the set of selected
 35    # columns for a particular scope are completely build by the time we get to it.
 36    for scope in reversed(traverse_scope(expression)):
 37        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 38
 39        if scope.expression.args.get("distinct"):
 40            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
 41            parent_selections = {SELECT_ALL}
 42
 43        if isinstance(scope.expression, exp.Union):
 44            left_union, right_union = scope.union_scopes
 45            referenced_columns[left_union] = parent_selections
 46            referenced_columns[right_union] = parent_selections
 47
 48        if isinstance(scope.expression, exp.Select) and scope != right_union:
 49            removed_indexes = _remove_unused_selections(scope, parent_selections)
 50            # The left union is used for column names to select and if we remove columns from the left
 51            # we need to also remove those same columns in the right that were at the same position
 52            if scope is left_union:
 53                _remove_indexed_selections(right_union, removed_indexes)
 54
 55            # Group columns by source name
 56            selects = defaultdict(set)
 57            for col in scope.columns:
 58                table_name = col.table
 59                col_name = col.name
 60                selects[table_name].add(col_name)
 61
 62            # Push the selected columns down to the next scope
 63            for name, (_, source) in scope.selected_sources.items():
 64                if isinstance(source, Scope):
 65                    columns = selects.get(name) or set()
 66                    referenced_columns[source].update(columns)
 67
 68    return expression
 69
 70
 71def _remove_unused_selections(scope, parent_selections):
 72    removed_indexes = []
 73    order = scope.expression.args.get("order")
 74
 75    if order:
 76        # Assume columns without a qualified table are references to output columns
 77        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
 78    else:
 79        order_refs = set()
 80
 81    new_selections = []
 82    removed = False
 83    for i, selection in enumerate(scope.selects):
 84        if (
 85            SELECT_ALL in parent_selections
 86            or selection.alias_or_name in parent_selections
 87            or selection.alias_or_name in order_refs
 88        ):
 89            new_selections.append(selection)
 90        else:
 91            removed_indexes.append(i)
 92            removed = True
 93
 94    # If there are no remaining selections, just select a single constant
 95    if not new_selections:
 96        new_selections.append(DEFAULT_SELECTION())
 97
 98    scope.expression.set("expressions", new_selections)
 99    if removed:
100        scope.clear_cache()
101    return removed_indexes
102
103
def _remove_indexed_selections(scope, indexes_to_remove):
    """
    Drop the projections of ``scope`` at the given positions.

    Used to keep a union's right-hand branch aligned with its left-hand branch after
    ``_remove_unused_selections`` pruned the left one. If nothing would remain, a single
    ``DEFAULT_SELECTION()`` projection is used instead.

    Args:
        scope (Scope): scope whose projections are pruned; its expression is modified in place.
        indexes_to_remove (Iterable[int]): zero-based positions to remove.
    """
    # Set membership keeps the filter linear in the number of projections.
    to_remove = set(indexes_to_remove)
    new_selections = [
        selection for i, selection in enumerate(scope.selects) if i not in to_remove
    ]
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION())
    scope.expression.set("expressions", new_selections)
    if to_remove:
        # Like _remove_unused_selections: cached scope data may describe the old projections.
        scope.clear_cache()
def DEFAULT_SELECTION():
    """Return a ``1 AS _`` projection, built anew on each call so callers never share a node."""
    return alias("1", "_")
def pushdown_projections(expression):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)
    # Map of a union's left-most SELECT scope to the right-hand branch scopes that must be
    # pruned at the same positions. This is tracked per scope (not in one pair of variables)
    # so that other unions met later in the traversal cannot overwrite the pairing.
    positional_targets = defaultdict(list)
    # Scopes on the right-hand side of a union: their projections are pruned only by
    # position, mirroring what is removed from the matching left-most branch.
    right_branches = set()

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections
            referenced_columns[right] = parent_selections
            if scope in right_branches:
                # A union nested on the right of another union is pruned positionally as a
                # whole, so both of its branches are as well.
                right_branches.update((left, right))
            else:
                right_branches.add(right)
                # If this union is itself the left branch of an outer union, its left branch
                # also inherits the outer union's right-hand branches.
                positional_targets[left] = [right, *positional_targets.pop(scope, [])]

        if isinstance(scope.expression, exp.Select) and scope not in right_branches:
            removed_indexes = _remove_unused_selections(scope, parent_selections)
            # The left-most union branch is used for column names; whatever is removed from it
            # must also be removed at the same positions from every right-hand branch.
            # NOTE(review): right-hand branches are pruned even if they are SELECT DISTINCT,
            # as in the original implementation — confirm that is intended.
            for target in positional_targets.pop(scope, []):
                for branch in _union_leaf_scopes(target):
                    _remove_indexed_selections(branch, removed_indexes)

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    referenced_columns[source].update(selects.get(name, ()))

    return expression


def _union_leaf_scopes(scope):
    """Yield the non-union scopes that make up ``scope``, descending through nested unions."""
    if isinstance(scope.expression, exp.Union):
        for branch in scope.union_scopes:
            yield from _union_leaf_scopes(branch)
    else:
        yield scope

Rewrite sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:
  • sqlglot.Expression: optimized expression