Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.optimizer.qualify_columns import Resolver
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6from sqlglot.schema import ensure_schema
  7
  8# Sentinel value meaning that an outer query selects ALL columns
  9SELECT_ALL = object()
 10
 11
 12# Selection to use if selection list is empty
 13def default_selection(is_agg: bool) -> exp.Alias:
 14    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
 15
 16
 17def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 18    """
 19    Rewrite sqlglot AST to remove unused columns projections.
 20
 21    Example:
 22        >>> import sqlglot
 23        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 24        >>> expression = sqlglot.parse_one(sql)
 25        >>> pushdown_projections(expression).sql()
 26        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 27
 28    Args:
 29        expression (sqlglot.Expression): expression to optimize
 30        remove_unused_selections (bool): remove selects that are unused
 31    Returns:
 32        sqlglot.Expression: optimized expression
 33    """
 34    # Map of Scope to all columns being selected by outer queries.
 35    schema = ensure_schema(schema)
 36    source_column_alias_count = {}
 37    referenced_columns = defaultdict(set)
 38
 39    # We build the scope tree (which is traversed in DFS postorder), then iterate
 40    # over the result in reverse order. This should ensure that the set of selected
 41    # columns for a particular scope are completely build by the time we get to it.
 42    for scope in reversed(traverse_scope(expression)):
 43        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 44        alias_count = source_column_alias_count.get(scope, 0)
 45
 46        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
 47        if scope.expression.args.get("distinct"):
 48            parent_selections = {SELECT_ALL}
 49
 50        if isinstance(scope.expression, exp.Union):
 51            left, right = scope.union_scopes
 52            referenced_columns[left] = parent_selections
 53
 54            if any(select.is_star for select in right.expression.selects):
 55                referenced_columns[right] = parent_selections
 56            elif not any(select.is_star for select in left.expression.selects):
 57                if scope.expression.args.get("by_name"):
 58                    referenced_columns[right] = referenced_columns[left]
 59                else:
 60                    referenced_columns[right] = [
 61                        right.expression.selects[i].alias_or_name
 62                        for i, select in enumerate(left.expression.selects)
 63                        if SELECT_ALL in parent_selections
 64                        or select.alias_or_name in parent_selections
 65                    ]
 66
 67        if isinstance(scope.expression, exp.Select):
 68            if remove_unused_selections:
 69                _remove_unused_selections(scope, parent_selections, schema, alias_count)
 70
 71            if scope.expression.is_star:
 72                continue
 73
 74            # Group columns by source name
 75            selects = defaultdict(set)
 76            for col in scope.columns:
 77                table_name = col.table
 78                col_name = col.name
 79                selects[table_name].add(col_name)
 80
 81            # Push the selected columns down to the next scope
 82            for name, (node, source) in scope.selected_sources.items():
 83                if isinstance(source, Scope):
 84                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
 85                    referenced_columns[source].update(columns)
 86
 87                column_aliases = node.alias_column_names
 88                if column_aliases:
 89                    source_column_alias_count[source] = len(column_aliases)
 90
 91    return expression
 92
 93
 94def _remove_unused_selections(scope, parent_selections, schema, alias_count):
 95    order = scope.expression.args.get("order")
 96
 97    if order:
 98        # Assume columns without a qualified table are references to output columns
 99        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
100    else:
101        order_refs = set()
102
103    new_selections = []
104    removed = False
105    star = False
106    is_agg = False
107
108    select_all = SELECT_ALL in parent_selections
109
110    for selection in scope.expression.selects:
111        name = selection.alias_or_name
112
113        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
114            new_selections.append(selection)
115            alias_count -= 1
116        else:
117            if selection.is_star:
118                star = True
119            removed = True
120
121        if not is_agg and selection.find(exp.AggFunc):
122            is_agg = True
123
124    if star:
125        resolver = Resolver(scope, schema)
126        names = {s.alias_or_name for s in new_selections}
127
128        for name in sorted(parent_selections):
129            if name not in names:
130                new_selections.append(
131                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
132                )
133
134    # If there are no remaining selections, just select a single constant
135    if not new_selections:
136        new_selections.append(default_selection(is_agg))
137
138    scope.expression.select(*new_selections, append=False, copy=False)
139
140    if removed:
141        scope.clear_cache()
SELECT_ALL = <object object>
def default_selection(is_agg: bool) -> sqlglot.expressions.Alias:
def default_selection(is_agg: bool) -> exp.Alias:
    """
    Build the placeholder projection used when a selection list would be empty.

    Args:
        is_agg (bool): whether the pruned selection list contained an aggregate

    Returns:
        exp.Alias: `MAX(1)` aliased as `_` when `is_agg` is true, otherwise `1`
        aliased as `_`
    """
    if is_agg:
        placeholder = exp.Max(this=exp.Literal.number(1))
    else:
        placeholder = "1"

    return alias(placeholder, "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: schema description, normalized with `ensure_schema`; it is handed
            to the column resolver when a pruned `*` selection has to be replaced
            by explicit columns
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    schema = ensure_schema(schema)
    # Map of source to the number of column aliases declared on the node that
    # references it (e.g. `AS t(a, b)`).
    source_column_alias_count = {}
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.expression.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.expression.selects):
                if scope.expression.args.get("by_name"):
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Branches are matched by position: keep the right-hand
                    # selection at every index whose left-hand selection is needed.
                    # A set keeps the value type consistent with the other entries
                    # of `referenced_columns` and drops duplicate names.
                    # NOTE: indexing raises IndexError if the right branch has
                    # fewer selections than the left one.
                    keep_all = SELECT_ALL in parent_selections
                    right_selects = right.expression.selects
                    referenced_columns[right] = {
                        right_selects[i].alias_or_name
                        for i, select in enumerate(left.expression.selects)
                        if keep_all or select.alias_or_name in parent_selections
                    }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            # A star selection needs every column of its sources, so there is
            # nothing to narrow down for the scopes below it.
            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                table_name = col.table
                col_name = col.name
                selects[table_name].add(col_name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    # When the scope has pivots, every column of the source is requested.
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression

Rewrite sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • remove_unused_selections (bool): remove selects that are unused
Returns:

sqlglot.Expression: optimized expression