Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.optimizer.qualify_columns import Resolver
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6from sqlglot.schema import ensure_schema
  7
  8# Sentinel value that means an outer query selecting ALL columns
  9SELECT_ALL = object()
 10
 11# Selection to use if selection list is empty
 12DEFAULT_SELECTION = lambda: alias("1", "_")
 13
 14
 15def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 16    """
 17    Rewrite sqlglot AST to remove unused columns projections.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 22        >>> expression = sqlglot.parse_one(sql)
 23        >>> pushdown_projections(expression).sql()
 24        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 25
 26    Args:
 27        expression (sqlglot.Expression): expression to optimize
 28        remove_unused_selections (bool): remove selects that are unused
 29    Returns:
 30        sqlglot.Expression: optimized expression
 31    """
 32    # Map of Scope to all columns being selected by outer queries.
 33    schema = ensure_schema(schema)
 34    referenced_columns = defaultdict(set)
 35
 36    # We build the scope tree (which is traversed in DFS postorder), then iterate
 37    # over the result in reverse order. This should ensure that the set of selected
 38    # columns for a particular scope are completely build by the time we get to it.
 39    for scope in reversed(traverse_scope(expression)):
 40        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 41
 42        if scope.expression.args.get("distinct"):
 43            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
 44            parent_selections = {SELECT_ALL}
 45
 46        if isinstance(scope.expression, exp.Union):
 47            left, right = scope.union_scopes
 48            referenced_columns[left] = parent_selections
 49
 50            if any(select.is_star for select in right.selects):
 51                referenced_columns[right] = parent_selections
 52            elif not any(select.is_star for select in left.selects):
 53                referenced_columns[right] = [
 54                    right.selects[i].alias_or_name
 55                    for i, select in enumerate(left.selects)
 56                    if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
 57                ]
 58
 59        if isinstance(scope.expression, exp.Select):
 60            if remove_unused_selections:
 61                _remove_unused_selections(scope, parent_selections, schema)
 62
 63            # Group columns by source name
 64            selects = defaultdict(set)
 65            for col in scope.columns:
 66                table_name = col.table
 67                col_name = col.name
 68                selects[table_name].add(col_name)
 69
 70            # Push the selected columns down to the next scope
 71            for name, (_, source) in scope.selected_sources.items():
 72                if isinstance(source, Scope):
 73                    columns = selects.get(name) or set()
 74                    referenced_columns[source].update(columns)
 75
 76    return expression
 77
 78
 79def _remove_unused_selections(scope, parent_selections, schema):
 80    order = scope.expression.args.get("order")
 81
 82    if order:
 83        # Assume columns without a qualified table are references to output columns
 84        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
 85    else:
 86        order_refs = set()
 87
 88    new_selections = []
 89    removed = False
 90    star = False
 91
 92    for selection in scope.selects:
 93        name = selection.alias_or_name
 94
 95        if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
 96            new_selections.append(selection)
 97        else:
 98            if selection.is_star:
 99                star = True
100            removed = True
101
102    if star:
103        resolver = Resolver(scope, schema)
104        names = {s.alias_or_name for s in new_selections}
105
106        for name in sorted(parent_selections):
107            if name not in names:
108                new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
109
110    # If there are no remaining selections, just select a single constant
111    if not new_selections:
112        new_selections.append(DEFAULT_SELECTION())
113
114    scope.expression.select(*new_selections, append=False, copy=False)
115
116    if removed:
117        scope.clear_cache()
def DEFAULT_SELECTION():
    """Build the projection to use when a selection list would otherwise be empty."""
    return alias("1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and handed to
            the column resolver when a pruned `*` projection has to be expanded
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression
    """
    schema = ensure_schema(schema)
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # A scope nobody has registered columns for is treated as fully selected.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.selects):
                # The parent's column names are matched against the left branch;
                # translate them by position into the right branch's own names.
                # Stored as a set so every value in `referenced_columns` has the
                # same type and supports `.update`.
                select_all = SELECT_ALL in parent_selections
                referenced_columns[right] = {
                    right.selects[i].alias_or_name
                    for i, select in enumerate(left.selects)
                    if select_all or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema)

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope. The entry is
            # created even when no column is referenced, so that the child scope
            # is not mistaken for one that is fully selected.
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    referenced_columns[source].update(selects.get(name, ()))

    return expression

Rewrite a sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • remove_unused_selections (bool): remove selects that are unused
Returns:

sqlglot.Expression: optimized expression