summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/eliminate_joins.py
blob: 0854336a33628fc3967d1f38b0901bff84cfbc39 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify


def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    for scope in traverse_scope(expression):
        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
        # It's probably possible to infer this from the outputs of derived tables.
        # But for now, let's just skip this rule.
        if scope.unqualified_columns:
            continue

        # `or []` also covers a "joins" arg that is present but set to None,
        # which `.get("joins", [])` would hand back unchanged.
        joins = scope.expression.args.get("joins") or []

        # Visit the joins last-to-first so we can remove chains of unused joins
        # (presumably relies on Scope.remove_source refreshing column bookkeeping — TODO confirm).
        # Iterate over a snapshot because join.pop() may mutate `joins` in place.
        for join in list(reversed(joins)):
            alias = join.this.alias_or_name
            if _should_eliminate_join(scope, join, alias):
                join.pop()
                scope.remove_source(alias)
    return expression


def _should_eliminate_join(scope, join, alias):
    """
    Decide whether `join` can be removed from `scope` without changing the result.

    The join's source must be a derived table (a Scope). Its columns must not be
    referenced anywhere except the join's own ON clause. In addition, one of these
    must hold:
    - it is a LEFT JOIN whose ON clause equates every unique output of the source, or
    - it has neither an ON nor a USING condition and the source yields a single row.

    Args:
        scope: the scope that owns the join.
        join: the join expression under consideration.
        alias: the name the join's source is registered under in `scope.sources`.
    Returns:
        bool: True if the join may be dropped.
    """
    inner_source = scope.sources.get(alias)
    return (
        isinstance(inner_source, Scope)
        and not _join_is_used(scope, join, alias)
        and (
            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
            or (
                not join.args.get("on")
                # USING (...) filters rows just like ON does, but neither join_condition
                # nor _join_is_used looks at it, so such joins must never take this path.
                and not join.args.get("using")
                # NOTE(review): for non-LEFT joins the source must yield *exactly* one row;
                # _has_single_output_row also accepts LIMIT 1 / FROM-less sources, which can
                # yield zero rows and would then empty the outer result — confirm intent.
                and _has_single_output_row(inner_source)
            )
        )
    )


def _join_is_used(scope, join, alias):
    """
    Report whether any column of `scope` reads from `alias` outside this join's ON clause.

    Columns that appear in the join's own ON clause are not a "use" of the join,
    because they disappear together with the join when it is removed.
    """
    condition = join.args.get("on")
    # Track ON-clause columns by object identity, not structural equality, so an
    # identical-looking column elsewhere in the query still counts as a use.
    ignored = {id(col) for col in condition.find_all(exp.Column)} if condition else set()

    for col in scope.source_columns(alias):
        if id(col) in ignored:
            continue
        if col:
            return True
    return False


def _is_joined_on_all_unique_outputs(scope, join):
    """
    Return whether `join`'s equality keys cover every unique output of `scope`.

    When they do, each outer row can match at most one row of `scope`. Join keys are
    matched to output columns by their `.name`. If `scope` has no known unique
    outputs, this returns False.
    """
    unique_outputs = _unique_outputs(scope)
    if not unique_outputs:
        return False

    key_names = {key.name for key in join_condition(join)[1]}
    return unique_outputs.issubset(key_names)


def _unique_outputs(scope):
    """
    Return the output names of `scope` whose combined values are unique per row.

    The guarantee comes from, in order of precedence:
    - SELECT DISTINCT: all named outputs;
    - GROUP BY: the projections that are GROUP BY keys, but only if every key is projected;
    - a single-row query: all named outputs.

    An empty set means no uniqueness guarantee could be established.
    """
    query = scope.expression

    if query.args.get("distinct"):
        return set(query.named_selects)

    group = query.args.get("group")
    if group:
        keys = set(group.expressions)
        covered_keys = set()
        names = set()
        for projection in scope.selects:
            projected = projection.unalias()
            if projected in keys:
                covered_keys.add(projected)
                names.add(projection.alias_or_name)
        # A GROUP BY key missing from the output would let projected values repeat.
        return names if keys <= covered_keys else set()

    return set(query.named_selects) if _has_single_output_row(scope) else set()


def _has_single_output_row(scope):
    """
    Return whether the query of `scope` is treated as producing at most one row.

    That is the case for a SELECT that:
    - projects only aggregate functions and has no GROUP BY (a single group over all
      input rows; with GROUP BY there is one row per group),
    - has LIMIT 1, or
    - has no FROM clause.
    Any non-SELECT query (e.g. a set operation) returns False.
    """
    query = scope.expression
    if not isinstance(query, exp.Select):
        return False

    aggregates_only = not query.args.get("group") and all(
        isinstance(e.unalias(), exp.AggFunc) for e in scope.selects
    )
    return aggregates_only or _is_limit_1(scope) or not query.args.get("from")


def _is_limit_1(scope):
    """
    Return True if the query of `scope` has a literal `LIMIT 1`.

    A limit node without an `expression` (presumably a FETCH-style clause stored
    under the same key — verify) is conservatively treated as not LIMIT 1,
    instead of raising AttributeError.
    """
    limit = scope.expression.args.get("limit")
    if not limit:
        return False
    count = limit.expression
    return count is not None and count.this == "1"


def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equality predicates are pulled out as (source key, join key) pairs when one side
    references only the joined table and the other side does not reference it.
    Everything else is returned as the remaining predicate. The input join is not
    modified; all rewriting happens on a copy of its ON clause.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression | None]:
            Tuple of (source key, join key, remaining predicate). The remaining
            predicate is None when nothing but the keys was in the ON clause.
    """
    name = join.this.alias_or_name
    on = join.args.get("on")
    on = on.copy() if on else exp.Boolean(this=True)
    source_key = []
    join_key = []

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # Snapshot the conditions before rewriting any of them. A lone predicate is
        # unnested so that `ON (x.a = y.b)` is recognised as well.
        conditions = list(on.flatten()) if isinstance(on, exp.And) else [on.unnest()]
        for condition in conditions:
            if not isinstance(condition, exp.EQ):
                continue

            left, right = condition.unnest_operands()
            left_tables = exp.column_table_names(left)
            right_tables = exp.column_table_names(right)

            if name in left_tables and name not in right_tables:
                join_key.append(left)
                source_key.append(right)
            elif name in right_tables and name not in left_tables:
                join_key.append(right)
                source_key.append(left)
            else:
                continue

            # Always splice in a fresh TRUE node: using the shared exp.TRUE singleton
            # would re-parent a module-level global and place one node in several spots.
            if condition is on:
                # The lone condition is the root of `on` and has no parent to replace
                # it in, so rebind the root directly.
                on = exp.Boolean(this=True)
            else:
                condition.replace(exp.Boolean(this=True))

    on = simplify(on)
    remaining_condition = None if on == exp.TRUE else on

    return source_key, join_key, remaining_condition