summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/pushdown_projections.py
blob: abd949257f3970cb021c3cdb4051b0de3ba75c7d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.optimizer.scope import Scope, traverse_scope

# Sentinel placed in a scope's set of referenced columns meaning that every projection of
# that scope must be kept (the outer query selects all of its columns, or the scope is
# DISTINCT and removing a column would change its result).
SELECT_ALL = object()

# Selection to use if a selection list would otherwise be empty.
# NOTE: this is a single, shared AST node. Insert a copy (DEFAULT_SELECTION.copy()) rather
# than the node itself, otherwise the same node ends up attached to several expression trees.
DEFAULT_SELECTION = alias("1", "_")


def pushdown_projections(expression):
    """
    Rewrite sqlglot AST to remove unused columns projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Projections of SELECT DISTINCT / UNION DISTINCT scopes are never removed. For a union
    whose branches are both plain SELECTs, the left branch decides by name which
    projections to drop, and the right branch drops the projections at the same positions.
    When a union branch is itself a union (e.g. ``a UNION b UNION c``), all projections of
    the whole union tree are kept so every branch keeps the same number of columns.

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)
    # Left-hand SELECT branch of a (non-nested) union -> its right-hand SELECT branch.
    # A mapping (rather than a single "current union" variable) keeps each pair intact even
    # when other unions are visited between a union and its left branch.
    union_right_of = {}
    # Right-hand SELECT branches of (non-nested) unions. They are trimmed positionally when
    # their left sibling is processed, so they are not trimmed by name themselves.
    union_right_scopes = set()

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope are completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            if isinstance(left.expression, exp.Union) or isinstance(right.expression, exp.Union):
                # Positional trimming only reaches a direct left->right SELECT pair, so
                # pruning a nested union could leave its branches with different column
                # counts. Keep everything; the SELECT_ALL set flows down to inner unions.
                parent_selections = {SELECT_ALL}
            else:
                union_right_of[left] = right
                union_right_scopes.add(right)
            referenced_columns[left] = parent_selections
            referenced_columns[right] = parent_selections

        if isinstance(scope.expression, exp.Select) and scope not in union_right_scopes:
            removed_indexes = _remove_unused_selections(scope, parent_selections)
            # The left union branch is used for column names to select; if we remove columns
            # from it we also remove the columns at the same positions in its right sibling.
            right_sibling = union_right_of.get(scope)
            if right_sibling is not None:
                _remove_indexed_selections(right_sibling, removed_indexes)

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    referenced_columns[source].update(selects.get(name) or set())

    return expression


def _remove_unused_selections(scope, parent_selections):
    """
    Drop the projections of a SELECT scope that are not referenced by outer queries.

    A projection is kept when `parent_selections` contains SELECT_ALL, when its alias/name
    is in `parent_selections`, or when an unqualified column in the scope's ORDER BY refers
    to it. If nothing would remain, a single constant projection is used instead.

    Args:
        scope (Scope): scope whose expression is a SELECT (the only caller checks this).
        parent_selections (set): column names referenced by outer queries, possibly
            containing the SELECT_ALL sentinel.
    Returns:
        list[int]: positions, in the original projection list, of the removed projections.
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    keep_all = SELECT_ALL in parent_selections
    removed_indexes = []
    new_selections = []
    for i, selection in enumerate(scope.selects):
        name = selection.alias_or_name
        if keep_all or name in parent_selections or name in order_refs:
            new_selections.append(selection)
        else:
            removed_indexes.append(i)

    # If there are no remaining selections, just select a single constant. Copy the shared
    # DEFAULT_SELECTION node: `set` re-parents its value, so inserting the shared instance
    # would attach the same node to several expression trees.
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION.copy())

    scope.expression.set("expressions", new_selections)
    return removed_indexes


def _remove_indexed_selections(scope, indexes_to_remove):
    """
    Drop the projections of `scope` at the given positions.

    Used to keep the right-hand branch of a union aligned, column for column, with its
    left-hand branch after the latter was trimmed by name.

    Args:
        scope (Scope): SELECT scope whose projections are trimmed in place.
        indexes_to_remove (Iterable[int]): zero-based positions of projections to drop.
    """
    # Set membership keeps the filter linear in the number of projections.
    to_remove = set(indexes_to_remove)
    new_selections = [
        selection for i, selection in enumerate(scope.selects) if i not in to_remove
    ]
    # Never leave a SELECT without projections. Copy the shared DEFAULT_SELECTION node so
    # the same node is not attached to several expression trees.
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION.copy())
    scope.expression.set("expressions", new_selections)