diff options
Diffstat (limited to 'sqlglot/optimizer/eliminate_joins.py')
-rw-r--r-- | sqlglot/optimizer/eliminate_joins.py | 160 |
1 files changed, 160 insertions, 0 deletions
# Recovered from: diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
# (new file mode 100644, index 0000000..0854336)
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify


def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    for scope in traverse_scope(expression):
        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
        # It's probably possible to infer this from the outputs of derived tables.
        # But for now, let's just skip this rule.
        if scope.unqualified_columns:
            continue

        # `or []` also covers a "joins" key that is present but set to None.
        joins = scope.expression.args.get("joins") or []

        # Reverse the joins so we can remove chains of unused joins
        for join in reversed(joins):
            alias = join.this.alias_or_name
            if _should_eliminate_join(scope, join, alias):
                join.pop()
                scope.remove_source(alias)
    return expression


def _should_eliminate_join(scope, join, alias):
    """
    Decide whether `join` can be dropped from `scope`.

    A join is dropped only when all of the following hold:
      * the joined source is itself a `Scope` (i.e. not a plain table),
      * no column outside the join's ON clause references `alias`,
      * either it is a LEFT join whose equality keys cover every unique output
        of the joined scope, or it has no ON clause and the joined scope is
        known to produce a single row.
    """
    inner_source = scope.sources.get(alias)
    return (
        isinstance(inner_source, Scope)
        and not _join_is_used(scope, join, alias)
        and (
            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
            or (not join.args.get("on") and _has_single_output_row(inner_source))
        )
    )


def _join_is_used(scope, join, alias):
    """Return True if any column outside the join's ON clause references `alias`."""
    # We need to find all columns that reference this join.
    # But columns in the ON clause shouldn't count.
    on = join.args.get("on")
    if on:
        # Compare by identity: the same column node must be excluded, not an equal-looking one.
        on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
    else:
        on_clause_columns = set()

    # Test the predicate itself rather than relying on the truthiness of the column objects.
    return any(
        id(column) not in on_clause_columns for column in scope.source_columns(alias)
    )


def _is_joined_on_all_unique_outputs(scope, join):
    """
    Return True if the join's equality keys cover every unique output of `scope`.

    Returns False when `scope` has no known unique outputs.
    """
    unique_outputs = _unique_outputs(scope)
    if not unique_outputs:
        return False

    _, join_keys, _ = join_condition(join)
    remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
    return not remaining_unique_outputs


def _unique_outputs(scope):
    """Determine output columns of `scope` that must have a unique combination per row"""
    if scope.expression.args.get("distinct"):
        return set(scope.expression.named_selects)

    group = scope.expression.args.get("group")
    if group:
        grouped_expressions = set(group.expressions)
        grouped_outputs = set()

        unique_outputs = set()
        for select in scope.selects:
            output = select.unalias()
            if output in grouped_expressions:
                grouped_outputs.add(output)
                unique_outputs.add(select.alias_or_name)

        # All the grouped expressions must be in the output
        if not grouped_expressions.difference(grouped_outputs):
            return unique_outputs
        return set()

    if _has_single_output_row(scope):
        return set(scope.expression.named_selects)

    return set()


def _has_single_output_row(scope):
    """
    Return True if `scope` is a SELECT that is known to produce a single row.

    That is the case when:
      * every projection is an aggregate function and there is no GROUP BY, or
      * the query has `LIMIT 1`, or
      * the query has no FROM clause.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return False

    # An all-aggregate projection yields one row per group, so it is only a
    # single row when the query has no GROUP BY.
    only_aggregates = not select.args.get("group") and all(
        isinstance(e.unalias(), exp.AggFunc) for e in scope.selects
    )
    return bool(only_aggregates or _is_limit_1(scope) or not select.args.get("from"))


def _is_limit_1(scope):
    """Return True if the scope's expression has a LIMIT whose value is the literal `1`."""
    limit = scope.expression.args.get("limit")
    # Always return a real boolean, also when there is no LIMIT at all.
    return bool(limit and limit.expression.this == "1")


def join_condition(join):
    """
    Extract the join condition from a join expression.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression | None]:
            Tuple of (source key, join key, remaining predicate).
            The keys are the operands of the extracted equality conditions.
            The remaining predicate is None when nothing but TRUE is left.
    """
    name = join.this.alias_or_name
    on = join.args.get("on") or exp.TRUE
    # Work on a copy: extracted conditions are replaced below.
    on = on.copy()
    source_key = []
    join_key = []

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        for condition in on.flatten() if isinstance(on, exp.And) else [on]:
            if isinstance(condition, exp.EQ):
                left, right = condition.unnest_operands()
                left_tables = exp.column_table_names(left)
                right_tables = exp.column_table_names(right)

                # Only extract an equality when exactly one side references the joined source.
                if name in left_tables and name not in right_tables:
                    join_key.append(left)
                    source_key.append(right)
                    condition.replace(exp.TRUE)
                elif name in right_tables and name not in left_tables:
                    join_key.append(right)
                    source_key.append(left)
                    condition.replace(exp.TRUE)

    on = simplify(on)
    remaining_condition = None if on == exp.TRUE else on

    return source_key, join_key, remaining_condition