Edit on GitHub

sqlglot.optimizer.eliminate_joins

  1from sqlglot import expressions as exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import Scope, traverse_scope
  4from sqlglot.optimizer.simplify import simplify
  5
  6
  7def eliminate_joins(expression):
  8    """
  9    Remove unused joins from an expression.
 10
 11    This only removes joins when we know that the join condition doesn't produce duplicate rows.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
 16        >>> expression = sqlglot.parse_one(sql)
 17        >>> eliminate_joins(expression).sql()
 18        'SELECT x.a FROM x'
 19
 20    Args:
 21        expression (sqlglot.Expression): expression to optimize
 22    Returns:
 23        sqlglot.Expression: optimized expression
 24    """
 25    for scope in traverse_scope(expression):
 26        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
 27        # It's probably possible to infer this from the outputs of derived tables.
 28        # But for now, let's just skip this rule.
 29        if scope.unqualified_columns:
 30            continue
 31
 32        joins = scope.expression.args.get("joins", [])
 33
 34        # Reverse the joins so we can remove chains of unused joins
 35        for join in reversed(joins):
 36            alias = join.this.alias_or_name
 37            if _should_eliminate_join(scope, join, alias):
 38                join.pop()
 39                scope.remove_source(alias)
 40    return expression
 41
 42
 43def _should_eliminate_join(scope, join, alias):
 44    inner_source = scope.sources.get(alias)
 45    return (
 46        isinstance(inner_source, Scope)
 47        and not _join_is_used(scope, join, alias)
 48        and (
 49            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
 50            or (not join.args.get("on") and _has_single_output_row(inner_source))
 51        )
 52    )
 53
 54
 55def _join_is_used(scope, join, alias):
 56    # We need to find all columns that reference this join.
 57    # But columns in the ON clause shouldn't count.
 58    on = join.args.get("on")
 59    if on:
 60        on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
 61    else:
 62        on_clause_columns = set()
 63    return any(
 64        column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
 65    )
 66
 67
 68def _is_joined_on_all_unique_outputs(scope, join):
 69    unique_outputs = _unique_outputs(scope)
 70    if not unique_outputs:
 71        return False
 72
 73    _, join_keys, _ = join_condition(join)
 74    remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
 75    return not remaining_unique_outputs
 76
 77
 78def _unique_outputs(scope):
 79    """Determine output columns of `scope` that must have a unique combination per row"""
 80    if scope.expression.args.get("distinct"):
 81        return set(scope.expression.named_selects)
 82
 83    group = scope.expression.args.get("group")
 84    if group:
 85        grouped_expressions = set(group.expressions)
 86        grouped_outputs = set()
 87
 88        unique_outputs = set()
 89        for select in scope.selects:
 90            output = select.unalias()
 91            if output in grouped_expressions:
 92                grouped_outputs.add(output)
 93                unique_outputs.add(select.alias_or_name)
 94
 95        # All the grouped expressions must be in the output
 96        if not grouped_expressions.difference(grouped_outputs):
 97            return unique_outputs
 98        else:
 99            return set()
100
101    if _has_single_output_row(scope):
102        return set(scope.expression.named_selects)
103
104    return set()
105
106
def _has_single_output_row(scope):
    """
    Check whether the expression of `scope` is a SELECT treated as having a single output row.

    That is the case when any of the following holds:
      * every projection is an aggregate function (`exp.AggFunc`)
      * the query has a LIMIT of 1
      * the query has no FROM clause
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return False

    if all(isinstance(projection.unalias(), exp.AggFunc) for projection in scope.selects):
        return True

    return _is_limit_1(scope) or not select.args.get("from")
113
114
115def _is_limit_1(scope):
116    limit = scope.expression.args.get("limit")
117    return limit and limit.expression.this == "1"
118
119
def join_condition(join):
    """
    Split the ON clause of a join into equality keys and whatever predicate is left.

    The ON clause is copied first, so the join itself is not modified.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression | None]:
            Tuple of (source keys, join keys, remaining predicate). The remaining
            predicate is None when what is left of the ON clause simplifies to TRUE.
    """
    joined_name = join.this.alias_or_name
    predicate = (join.args.get("on") or exp.true()).copy()
    source_keys = []
    join_keys = []

    def pull_key(equality):
        # An equality only yields a key pair when exactly one of its sides references
        # the joined source; that side is the join key, the other one the source key.
        lhs, rhs = equality.unnest_operands()
        joined_in_lhs = joined_name in exp.column_table_names(lhs)
        joined_in_rhs = joined_name in exp.column_table_names(rhs)

        if joined_in_lhs == joined_in_rhs:
            return

        if joined_in_lhs:
            join_side, source_side = lhs, rhs
        else:
            join_side, source_side = rhs, lhs

        join_keys.append(join_side)
        source_keys.append(source_side)
        # The extracted equality no longer belongs to the remaining predicate
        equality.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(predicate):
        # Wrap a lone condition in an AND so it can be flattened like any conjunction
        if not isinstance(predicate, exp.And):
            predicate = exp.and_(predicate, exp.true())

        for term in predicate.flatten():
            if isinstance(term, exp.EQ):
                pull_key(term)
    elif normalized(predicate, dnf=True):
        # Keep only the equalities that show up in every branch of the disjunction
        shared = None

        for branch in predicate.flatten():
            equalities = [node for node in branch.flatten() if isinstance(node, exp.EQ)]
            if shared is None:
                shared = equalities
                continue

            overlap = []
            for candidate in equalities:
                twins = [known for known in shared if candidate == known]

                if twins:
                    overlap.append(candidate)
                    overlap.extend(twins)
            shared = overlap

        for equality in shared:
            pull_key(equality)

    predicate = simplify(predicate)
    if predicate == exp.true():
        return source_keys, join_keys, None
    return source_keys, join_keys, predicate
def eliminate_joins(expression):
 8def eliminate_joins(expression):
 9    """
10    Remove unused joins from an expression.
11
12    This only removes joins when we know that the join condition doesn't produce duplicate rows.
13
14    Example:
15        >>> import sqlglot
16        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
17        >>> expression = sqlglot.parse_one(sql)
18        >>> eliminate_joins(expression).sql()
19        'SELECT x.a FROM x'
20
21    Args:
22        expression (sqlglot.Expression): expression to optimize
23    Returns:
24        sqlglot.Expression: optimized expression
25    """
26    for scope in traverse_scope(expression):
27        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
28        # It's probably possible to infer this from the outputs of derived tables.
29        # But for now, let's just skip this rule.
30        if scope.unqualified_columns:
31            continue
32
33        joins = scope.expression.args.get("joins", [])
34
35        # Reverse the joins so we can remove chains of unused joins
36        for join in reversed(joins):
37            alias = join.this.alias_or_name
38            if _should_eliminate_join(scope, join, alias):
39                join.pop()
40                scope.remove_source(alias)
41    return expression

Remove unused joins from an expression.

This only removes joins when we know that the join condition doesn't produce duplicate rows.

Example:
>>> import sqlglot
>>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_joins(expression).sql()
'SELECT x.a FROM x'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:

sqlglot.Expression: optimized expression

def join_condition(join):
def join_condition(join):
    """
    Split the ON clause of a join into equality keys and whatever predicate is left.

    The ON clause is copied first, so the join itself is not modified.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression | None]:
            Tuple of (source keys, join keys, remaining predicate). The remaining
            predicate is None when what is left of the ON clause simplifies to TRUE.
    """
    joined_name = join.this.alias_or_name
    predicate = (join.args.get("on") or exp.true()).copy()
    source_keys = []
    join_keys = []

    def pull_key(equality):
        # An equality only yields a key pair when exactly one of its sides references
        # the joined source; that side is the join key, the other one the source key.
        lhs, rhs = equality.unnest_operands()
        joined_in_lhs = joined_name in exp.column_table_names(lhs)
        joined_in_rhs = joined_name in exp.column_table_names(rhs)

        if joined_in_lhs == joined_in_rhs:
            return

        if joined_in_lhs:
            join_side, source_side = lhs, rhs
        else:
            join_side, source_side = rhs, lhs

        join_keys.append(join_side)
        source_keys.append(source_side)
        # The extracted equality no longer belongs to the remaining predicate
        equality.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(predicate):
        # Wrap a lone condition in an AND so it can be flattened like any conjunction
        if not isinstance(predicate, exp.And):
            predicate = exp.and_(predicate, exp.true())

        for term in predicate.flatten():
            if isinstance(term, exp.EQ):
                pull_key(term)
    elif normalized(predicate, dnf=True):
        # Keep only the equalities that show up in every branch of the disjunction
        shared = None

        for branch in predicate.flatten():
            equalities = [node for node in branch.flatten() if isinstance(node, exp.EQ)]
            if shared is None:
                shared = equalities
                continue

            overlap = []
            for candidate in equalities:
                twins = [known for known in shared if candidate == known]

                if twins:
                    overlap.append(candidate)
                    overlap.extend(twins)
            shared = overlap

        for equality in shared:
            pull_key(equality)

    predicate = simplify(predicate)
    if predicate == exp.true():
        return source_keys, join_keys, None
    return source_keys, join_keys, predicate

Extract the join condition from a join expression.

Arguments:
  • join (exp.Join)
Returns:

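tuple[list[exp.Expression], list[exp.Expression], exp.Expression | None]: Tuple of (source keys, join keys, remaining predicate). The keys are the operand expressions of the extracted equalities, and the remaining predicate is None when what is left of the ON clause simplifies to TRUE.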