Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.helper import flatten
  5from sqlglot.optimizer.qualify_columns import Resolver
  6from sqlglot.optimizer.scope import Scope, traverse_scope
  7from sqlglot.schema import ensure_schema
  8
  9# Sentinel value that means an outer query selecting ALL columns
 10SELECT_ALL = object()
 11
 12# Selection to use if selection list is empty
 13DEFAULT_SELECTION = lambda: alias("1", "_")
 14
 15
 16def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 17    """
 18    Rewrite sqlglot AST to remove unused columns projections.
 19
 20    Example:
 21        >>> import sqlglot
 22        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 23        >>> expression = sqlglot.parse_one(sql)
 24        >>> pushdown_projections(expression).sql()
 25        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 26
 27    Args:
 28        expression (sqlglot.Expression): expression to optimize
 29        remove_unused_selections (bool): remove selects that are unused
 30    Returns:
 31        sqlglot.Expression: optimized expression
 32    """
 33    # Map of Scope to all columns being selected by outer queries.
 34    schema = ensure_schema(schema)
 35    referenced_columns = defaultdict(set)
 36
 37    # We build the scope tree (which is traversed in DFS postorder), then iterate
 38    # over the result in reverse order. This should ensure that the set of selected
 39    # columns for a particular scope are completely build by the time we get to it.
 40    for scope in reversed(traverse_scope(expression)):
 41        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 42
 43        if scope.expression.args.get("distinct"):
 44            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
 45            parent_selections = {SELECT_ALL}
 46
 47        if isinstance(scope.expression, exp.Union):
 48            left, right = scope.union_scopes
 49            referenced_columns[left] = parent_selections
 50
 51            if any(select.is_star for select in right.selects):
 52                referenced_columns[right] = parent_selections
 53            elif not any(select.is_star for select in left.selects):
 54                referenced_columns[right] = [
 55                    right.selects[i].alias_or_name
 56                    for i, select in enumerate(left.selects)
 57                    if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
 58                ]
 59
 60        if isinstance(scope.expression, exp.Select):
 61            if remove_unused_selections:
 62                _remove_unused_selections(scope, parent_selections, schema)
 63
 64            # Group columns by source name
 65            selects = defaultdict(set)
 66            for col in scope.columns:
 67                table_name = col.table
 68                col_name = col.name
 69                selects[table_name].add(col_name)
 70
 71            # Push the selected columns down to the next scope
 72            for name, (_, source) in scope.selected_sources.items():
 73                if isinstance(source, Scope):
 74                    columns = selects.get(name) or set()
 75                    referenced_columns[source].update(columns)
 76
 77    return expression
 78
 79
 80def _remove_unused_selections(scope, parent_selections, schema):
 81    order = scope.expression.args.get("order")
 82
 83    if order:
 84        # Assume columns without a qualified table are references to output columns
 85        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
 86    else:
 87        order_refs = set()
 88
 89    new_selections = defaultdict(list)
 90    removed = False
 91    star = False
 92    for selection in scope.selects:
 93        name = selection.alias_or_name
 94
 95        if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
 96            new_selections[name].append(selection)
 97        else:
 98            if selection.is_star:
 99                star = True
100            removed = True
101
102    if star:
103        resolver = Resolver(scope, schema)
104
105        for name in sorted(parent_selections):
106            if name not in new_selections:
107                new_selections[name].append(
108                    alias(exp.column(name, table=resolver.get_table(name)), name)
109                )
110
111    # If there are no remaining selections, just select a single constant
112    if not new_selections:
113        new_selections[""].append(DEFAULT_SELECTION())
114
115    scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
116
117    if removed:
118        scope.clear_cache()
def DEFAULT_SELECTION():
    """Return the projection used when a selection list would otherwise be empty.

    The projection is the constant ``1`` aliased as ``_``; a fresh node is
    built on every call so callers can attach it to a tree without sharing.
    """
    return alias("1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema (anything accepted by `ensure_schema`); used to
            resolve the source table of columns when a star projection has to
            be expanded into explicit columns
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression. The tree is modified in
        place and the same object is returned.
    """
    schema = ensure_schema(schema)

    # Map of Scope to the set of column names selected from it by outer queries.
    # A scope with no entry is treated as fully selected (SELECT_ALL).
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This ensures that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT:
            # dropping a column could change which rows count as distinct.
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.selects):
                # Match by position: keep the right-side projection at every
                # position whose left-side counterpart is selected. Stored as a
                # set, like every other value in referenced_columns.
                referenced_columns[right] = {
                    right.selects[i].alias_or_name
                    for i, select in enumerate(left.selects)
                    if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema)

            # Group the column names referenced in this scope by table qualifier
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope. Touching the entry
            # even when no columns are used records that the source is referenced
            # but needs none of its columns (instead of defaulting to SELECT_ALL).
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    referenced_columns[source].update(selects.get(name, ()))

    return expression

Rewrite sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
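
Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • schema: optional schema (anything accepted by `ensure_schema`) used to resolve the source table of columns when a star projection has to be expanded
  • remove_unused_selections (bool): remove selects that are unused
Returns:
  • sqlglot.Expression: optimized expression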