Edit on GitHub

sqlglot.optimizer.eliminate_joins

  1from sqlglot import expressions as exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import Scope, traverse_scope
  4
  5
  6def eliminate_joins(expression):
  7    """
  8    Remove unused joins from an expression.
  9
 10    This only removes joins when we know that the join condition doesn't produce duplicate rows.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
 15        >>> expression = sqlglot.parse_one(sql)
 16        >>> eliminate_joins(expression).sql()
 17        'SELECT x.a FROM x'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to optimize
 21    Returns:
 22        sqlglot.Expression: optimized expression
 23    """
 24    for scope in traverse_scope(expression):
 25        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
 26        # It's probably possible to infer this from the outputs of derived tables.
 27        # But for now, let's just skip this rule.
 28        if scope.unqualified_columns:
 29            continue
 30
 31        joins = scope.expression.args.get("joins", [])
 32
 33        # Reverse the joins so we can remove chains of unused joins
 34        for join in reversed(joins):
 35            alias = join.alias_or_name
 36            if _should_eliminate_join(scope, join, alias):
 37                join.pop()
 38                scope.remove_source(alias)
 39    return expression
 40
 41
 42def _should_eliminate_join(scope, join, alias):
 43    inner_source = scope.sources.get(alias)
 44    return (
 45        isinstance(inner_source, Scope)
 46        and not _join_is_used(scope, join, alias)
 47        and (
 48            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
 49            or (not join.args.get("on") and _has_single_output_row(inner_source))
 50        )
 51    )
 52
 53
 54def _join_is_used(scope, join, alias):
 55    # We need to find all columns that reference this join.
 56    # But columns in the ON clause shouldn't count.
 57    on = join.args.get("on")
 58    if on:
 59        on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
 60    else:
 61        on_clause_columns = set()
 62    return any(
 63        column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
 64    )
 65
 66
 67def _is_joined_on_all_unique_outputs(scope, join):
 68    unique_outputs = _unique_outputs(scope)
 69    if not unique_outputs:
 70        return False
 71
 72    _, join_keys, _ = join_condition(join)
 73    remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
 74    return not remaining_unique_outputs
 75
 76
 77def _unique_outputs(scope):
 78    """Determine output columns of `scope` that must have a unique combination per row"""
 79    if scope.expression.args.get("distinct"):
 80        return set(scope.expression.named_selects)
 81
 82    group = scope.expression.args.get("group")
 83    if group:
 84        grouped_expressions = set(group.expressions)
 85        grouped_outputs = set()
 86
 87        unique_outputs = set()
 88        for select in scope.expression.selects:
 89            output = select.unalias()
 90            if output in grouped_expressions:
 91                grouped_outputs.add(output)
 92                unique_outputs.add(select.alias_or_name)
 93
 94        # All the grouped expressions must be in the output
 95        if not grouped_expressions.difference(grouped_outputs):
 96            return unique_outputs
 97        else:
 98            return set()
 99
100    if _has_single_output_row(scope):
101        return set(scope.expression.named_selects)
102
103    return set()
104
105
def _has_single_output_row(scope):
    """
    Check whether the SELECT of `scope` produces at most one row.

    This holds for an all-aggregate select list without GROUP BY, for LIMIT 1, and for a
    SELECT without a FROM clause.

    Args:
        scope (Scope): the scope to inspect
    Returns:
        A truthy value if the scope yields a single row, otherwise a falsy value.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return False

    # Aggregates collapse the input to one row only without GROUP BY; with GROUP BY they
    # produce one row per group.
    only_aggregates = not select.args.get("group") and all(
        isinstance(e.unalias(), exp.AggFunc) for e in select.selects
    )
    return only_aggregates or _is_limit_1(scope) or not select.args.get("from")
112
113
def _is_limit_1(scope):
    """
    Check whether the query of `scope` has a literal `LIMIT 1`.

    Args:
        scope (Scope): the scope to inspect
    Returns:
        bool: True only when a LIMIT clause exists and its value is the literal 1
    """
    limit = scope.expression.args.get("limit")
    if not limit:
        return False
    # Guard against a Limit node without a value expression instead of raising AttributeError.
    value = limit.expression
    return value is not None and value.this == "1"
117
118
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equalities that compare an expression over the joined source (and only that source) with an
    expression over other sources are pulled out as key pairs. They are replaced by TRUE in a copy
    of the ON condition; the join's own ON condition is not modified. For example, with

        SELECT ... FROM x JOIN y ON x.a = y.b AND y.b > 1

    the join key is `y.b`, the source key is `x.a`, and the remaining predicate keeps `y.b > 1`.

    If the condition is in conjunctive normal form, every top-level equality is considered. If it
    is only in disjunctive normal form, only equalities that appear in every disjunct are used.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source keys, join keys, remaining predicate)
    """
    target = join.alias_or_name
    predicate = (join.args.get("on") or exp.true()).copy()
    source_keys = []
    join_keys = []

    def take_key(equality):
        # Record `equality` as a key pair when exactly one side references `target`,
        # then neutralise it inside `predicate`.
        lhs, rhs = equality.unnest_operands()
        lhs_tables = exp.column_table_names(lhs)
        rhs_tables = exp.column_table_names(rhs)

        if target in lhs_tables and target not in rhs_tables:
            ours, theirs = lhs, rhs
        elif target in rhs_tables and target not in lhs_tables:
            ours, theirs = rhs, lhs
        else:
            return

        join_keys.append(ours)
        source_keys.append(theirs)
        equality.replace(exp.true())

    if normalized(predicate):
        # Make sure `predicate` is an AND so `flatten()` yields its conjuncts (and, presumably,
        # so a lone equality has a parent for `replace` to update).
        if not isinstance(predicate, exp.And):
            predicate = exp.and_(predicate, exp.true(), copy=False)

        for conjunct in predicate.flatten():
            if isinstance(conjunct, exp.EQ):
                take_key(conjunct)
    elif normalized(predicate, dnf=True):
        # Intersect the equalities of all disjuncts; only shared ones can act as keys.
        common = None

        for disjunct in predicate.flatten():
            equalities = [node for node in disjunct.flatten() if isinstance(node, exp.EQ)]
            if common is None:
                common = equalities
                continue

            survivors = []
            for equality in equalities:
                matches = [seen for seen in common if equality == seen]
                if matches:
                    survivors.append(equality)
                    survivors.extend(matches)
            common = survivors

        for equality in common:
            take_key(equality)

    return source_keys, join_keys, predicate
def eliminate_joins(expression):
 7def eliminate_joins(expression):
 8    """
 9    Remove unused joins from an expression.
10
11    This only removes joins when we know that the join condition doesn't produce duplicate rows.
12
13    Example:
14        >>> import sqlglot
15        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
16        >>> expression = sqlglot.parse_one(sql)
17        >>> eliminate_joins(expression).sql()
18        'SELECT x.a FROM x'
19
20    Args:
21        expression (sqlglot.Expression): expression to optimize
22    Returns:
23        sqlglot.Expression: optimized expression
24    """
25    for scope in traverse_scope(expression):
26        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
27        # It's probably possible to infer this from the outputs of derived tables.
28        # But for now, let's just skip this rule.
29        if scope.unqualified_columns:
30            continue
31
32        joins = scope.expression.args.get("joins", [])
33
34        # Reverse the joins so we can remove chains of unused joins
35        for join in reversed(joins):
36            alias = join.alias_or_name
37            if _should_eliminate_join(scope, join, alias):
38                join.pop()
39                scope.remove_source(alias)
40    return expression

Remove unused joins from an expression.

This only removes joins when we know that the join condition doesn't produce duplicate rows.

Example:
>>> import sqlglot
>>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_joins(expression).sql()
'SELECT x.a FROM x'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:
  • sqlglot.Expression: optimized expression

def join_condition(join):
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equalities that compare an expression over the joined source (and only that source) with an
    expression over other sources are pulled out as key pairs. They are replaced by TRUE in a copy
    of the ON condition; the join's own ON condition is not modified. For example, with

        SELECT ... FROM x JOIN y ON x.a = y.b AND y.b > 1

    the join key is `y.b`, the source key is `x.a`, and the remaining predicate keeps `y.b > 1`.

    If the condition is in conjunctive normal form, every top-level equality is considered. If it
    is only in disjunctive normal form, only equalities that appear in every disjunct are used.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source keys, join keys, remaining predicate)
    """
    target = join.alias_or_name
    predicate = (join.args.get("on") or exp.true()).copy()
    source_keys = []
    join_keys = []

    def take_key(equality):
        # Record `equality` as a key pair when exactly one side references `target`,
        # then neutralise it inside `predicate`.
        lhs, rhs = equality.unnest_operands()
        lhs_tables = exp.column_table_names(lhs)
        rhs_tables = exp.column_table_names(rhs)

        if target in lhs_tables and target not in rhs_tables:
            ours, theirs = lhs, rhs
        elif target in rhs_tables and target not in lhs_tables:
            ours, theirs = rhs, lhs
        else:
            return

        join_keys.append(ours)
        source_keys.append(theirs)
        equality.replace(exp.true())

    if normalized(predicate):
        # Make sure `predicate` is an AND so `flatten()` yields its conjuncts (and, presumably,
        # so a lone equality has a parent for `replace` to update).
        if not isinstance(predicate, exp.And):
            predicate = exp.and_(predicate, exp.true(), copy=False)

        for conjunct in predicate.flatten():
            if isinstance(conjunct, exp.EQ):
                take_key(conjunct)
    elif normalized(predicate, dnf=True):
        # Intersect the equalities of all disjuncts; only shared ones can act as keys.
        common = None

        for disjunct in predicate.flatten():
            equalities = [node for node in disjunct.flatten() if isinstance(node, exp.EQ)]
            if common is None:
                common = equalities
                continue

            survivors = []
            for equality in equalities:
                matches = [seen for seen in common if equality == seen]
                if matches:
                    survivors.append(equality)
                    survivors.extend(matches)
            common = survivors

        for equality in common:
            take_key(equality)

    return source_keys, join_keys, predicate

Extract the join condition from a join expression.

Arguments:
  • join (exp.Join)
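Returns:
  • tuple[list[exp.Expression], list[exp.Expression], exp.Expression]: Tuple of (source keys, join keys, remaining predicate). The keys are the expressions on either side of each extracted equality. The remaining predicate is a copy of the ON condition with those equalities replaced by TRUE.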