summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/pushdown_projections.py
blob: 3f360f986b31c458a577f00d9ce2d3a1878630b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.helper import flatten
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema

# Sentinel value meaning that an outer query selects ALL columns of a scope,
# so none of that scope's projections may be removed.
SELECT_ALL = object()


def DEFAULT_SELECTION():
    """Build the placeholder projection `alias("1", "_")`.

    Used when pruning would otherwise leave a selection list empty. A new node
    is built on every call so that each query gets its own AST node.
    """
    return alias("1", "_")


def pushdown_projections(expression, schema=None):
    """
    Rewrite sqlglot AST to remove unused columns projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema (anything accepted by `ensure_schema`). It is used
            to resolve the table of each column when a pruned `*` projection has
            to be replaced by explicit columns.
    Returns:
        sqlglot.Expression: optimized expression (the input tree is modified in place)
    """
    schema = ensure_schema(schema)

    # Map of Scope to the set of column names selected from it by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if scope.expression.args.get("distinct"):
            # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT:
            # every projection takes part in deciding which rows are duplicates.
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            select_all = SELECT_ALL in parent_selections

            # Store copies so that the children never share one set object with
            # their parent (or with each other).
            referenced_columns[left] = set(parent_selections)

            if any(select.is_star for select in right.selects):
                referenced_columns[right] = set(parent_selections)
            elif not any(select.is_star for select in left.selects):
                # UNION branches are matched by position, so translate the names
                # required from the left branch into the right branch's names.
                # Kept as a set, like every other entry of `referenced_columns`.
                referenced_columns[right] = {
                    right.selects[i].alias_or_name
                    for i, select in enumerate(left.selects)
                    if select_all or select.alias_or_name in parent_selections
                }

        if isinstance(scope.expression, exp.Select):
            _remove_unused_selections(scope, parent_selections, schema)

            # Group the referenced column names by source (table) name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (_, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    referenced_columns[source].update(selects.get(name, ()))

    return expression


def _remove_unused_selections(scope, parent_selections, schema):
    """
    Drop the projections of `scope` that no outer query references.

    The projection list of `scope.expression` is rewritten in place. The
    original order of the kept projections is preserved.

    Args:
        scope: the SELECT scope whose projection list is pruned
        parent_selections: names selected from this scope by outer queries, or a
            collection containing SELECT_ALL when every projection must be kept
        schema: schema passed to `Resolver` to find each column's table when a
            pruned `*` projection is replaced by explicit columns
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    select_all = SELECT_ALL in parent_selections

    # A flat list (not grouped by name) preserves the original projection order,
    # even when several projections share the same output name.
    new_selections = []
    kept_names = set()
    removed = False
    star = False
    for selection in scope.selects:
        name = selection.alias_or_name

        if select_all or name in parent_selections or name in order_refs:
            new_selections.append(selection)
            kept_names.add(name)
        else:
            if selection.is_star:
                star = True
            removed = True

    if star:
        # A `*` was pruned: replace it with the specific columns the outer
        # queries need that are not already projected explicitly.
        resolver = Resolver(scope, schema)

        for name in sorted(parent_selections):
            if name not in kept_names:
                new_selections.append(
                    alias(exp.column(name, table=resolver.get_table(name)), name)
                )
                kept_names.add(name)

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION())

    scope.expression.select(*new_selections, append=False, copy=False)

    if removed:
        # Projections changed, so cached scope information (e.g. columns) is stale
        scope.clear_cache()