Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.helper import flatten
  5from sqlglot.optimizer.qualify_columns import Resolver
  6from sqlglot.optimizer.scope import Scope, traverse_scope
  7from sqlglot.schema import ensure_schema
  8
  9# Sentinel value that means an outer query selecting ALL columns
 10SELECT_ALL = object()
 11
 12# Selection to use if selection list is empty
 13DEFAULT_SELECTION = lambda: alias("1", "_")
 14
 15
 16def pushdown_projections(expression, schema=None):
 17    """
 18    Rewrite sqlglot AST to remove unused columns projections.
 19
 20    Example:
 21        >>> import sqlglot
 22        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 23        >>> expression = sqlglot.parse_one(sql)
 24        >>> pushdown_projections(expression).sql()
 25        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 26
 27    Args:
 28        expression (sqlglot.Expression): expression to optimize
 29    Returns:
 30        sqlglot.Expression: optimized expression
 31    """
 32    # Map of Scope to all columns being selected by outer queries.
 33    schema = ensure_schema(schema)
 34    referenced_columns = defaultdict(set)
 35
 36    # We build the scope tree (which is traversed in DFS postorder), then iterate
 37    # over the result in reverse order. This should ensure that the set of selected
 38    # columns for a particular scope are completely build by the time we get to it.
 39    for scope in reversed(traverse_scope(expression)):
 40        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 41
 42        if scope.expression.args.get("distinct"):
 43            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
 44            parent_selections = {SELECT_ALL}
 45
 46        if isinstance(scope.expression, exp.Union):
 47            left, right = scope.union_scopes
 48            referenced_columns[left] = parent_selections
 49
 50            if any(select.is_star for select in right.selects):
 51                referenced_columns[right] = parent_selections
 52            elif not any(select.is_star for select in left.selects):
 53                referenced_columns[right] = [
 54                    right.selects[i].alias_or_name
 55                    for i, select in enumerate(left.selects)
 56                    if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
 57                ]
 58
 59        if isinstance(scope.expression, exp.Select):
 60            _remove_unused_selections(scope, parent_selections, schema)
 61
 62            # Group columns by source name
 63            selects = defaultdict(set)
 64            for col in scope.columns:
 65                table_name = col.table
 66                col_name = col.name
 67                selects[table_name].add(col_name)
 68
 69            # Push the selected columns down to the next scope
 70            for name, (_, source) in scope.selected_sources.items():
 71                if isinstance(source, Scope):
 72                    columns = selects.get(name) or set()
 73                    referenced_columns[source].update(columns)
 74
 75    return expression
 76
 77
 78def _remove_unused_selections(scope, parent_selections, schema):
 79    order = scope.expression.args.get("order")
 80
 81    if order:
 82        # Assume columns without a qualified table are references to output columns
 83        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
 84    else:
 85        order_refs = set()
 86
 87    new_selections = defaultdict(list)
 88    removed = False
 89    star = False
 90    for selection in scope.selects:
 91        name = selection.alias_or_name
 92
 93        if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
 94            new_selections[name].append(selection)
 95        else:
 96            if selection.is_star:
 97                star = True
 98            removed = True
 99
100    if star:
101        resolver = Resolver(scope, schema)
102
103        for name in sorted(parent_selections):
104            if name not in new_selections:
105                new_selections[name].append(
106                    alias(exp.column(name, table=resolver.get_table(name)), name)
107                )
108
109    # If there are no remaining selections, just select a single constant
110    if not new_selections:
111        new_selections[""].append(DEFAULT_SELECTION())
112
113    scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
114
115    if removed:
116        scope.clear_cache()
def DEFAULT_SELECTION():
    """
    Return the projection to use when a selection list would otherwise be empty.

    The result is ``1 AS _``. A new expression is built on every call, so each
    caller gets its own node to attach to a tree.
    """
    return alias("1", "_")
def pushdown_projections(expression, schema=None):
    """
    Rewrite a sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize; it is rewritten
            in place and also returned.
        schema: optional schema, normalized with ``ensure_schema``. It is used to
            qualify the explicit columns that replace a pruned ``*`` projection.
    Returns:
        sqlglot.Expression: optimized expression
    """
    schema = ensure_schema(schema)

    # Map of Scope to the names of its output columns selected by outer queries.
    # A scope without an entry is treated as fully selected ({SELECT_ALL}).
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT: every
            # projected column takes part in the de-duplication.
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            select_all = SELECT_ALL in parent_selections
            # Copy, so neither branch shares a mutable set with this scope.
            referenced_columns[left] = set(parent_selections)

            if any(select.is_star for select in right.selects):
                referenced_columns[right] = set(parent_selections)
            elif not any(select.is_star for select in left.selects):
                # UNION pairs columns by position, so the right branch keeps the
                # positions the left branch keeps, looked up by its own names.
                referenced_columns[right] = {
                    right.selects[i].alias_or_name
                    for i, select in enumerate(left.selects)
                    if select_all or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            _remove_unused_selections(scope, parent_selections, schema)

            if any(select.is_star for select in scope.selects):
                # A surviving star depends on every column of its sources. Pushing
                # only the individually referenced columns would let the sources be
                # pruned (down to a constant), so mark them fully selected instead.
                selects = None
            else:
                # Group columns by source name
                selects = defaultdict(set)
                for col in scope.columns:
                    selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    if selects is None:
                        columns = {SELECT_ALL}
                    else:
                        columns = selects.get(name) or set()
                    referenced_columns[source].update(columns)

    return expression

Rewrite a sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
  • expression (sqlglot.Expression): expression to optimize (rewritten in place)
Returns:
  • sqlglot.Expression: optimized expression