sqlglot.optimizer.eliminate_joins
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify


def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize. It is modified in place.
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    for scope in traverse_scope(expression):
        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
        # It's probably possible to infer this from the outputs of derived tables.
        # But for now, let's just skip this rule.
        if scope.unqualified_columns:
            continue

        # The "joins" arg may be missing or explicitly None; treat both as "no joins".
        joins = scope.expression.args.get("joins") or []

        # Reverse the joins so we can remove chains of unused joins
        # (e.g. a join whose only references live in a later, removable join's ON clause).
        for join in reversed(joins):
            alias = join.this.alias_or_name
            if _should_eliminate_join(scope, join, alias):
                join.pop()
                scope.remove_source(alias)
    return expression


def _should_eliminate_join(scope, join, alias):
    """
    Return True if `join` (whose source is named `alias`) can be removed from `scope`.

    Only sources backed by a `Scope` (e.g. a derived table) are considered, and only when
    none of their columns are referenced outside the join's own ON clause.
    """
    inner_source = scope.sources.get(alias)
    return (
        isinstance(inner_source, Scope)
        and not _join_is_used(scope, join, alias)
        and (
            # A LEFT JOIN keeps every outer row; if the ON clause equates all of the inner
            # source's unique outputs, each outer row matches at most one inner row.
            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
            # A join with no condition at all against a single-row source. USING is a join
            # condition too (and join_condition only inspects ON), so such joins are excluded.
            # NOTE(review): _has_single_output_row means "at most one row" (e.g. LIMIT 1). If the
            # inner source can be empty, removing a non-LEFT join here changes the result - confirm.
            or (
                not join.args.get("on")
                and not join.args.get("using")
                and _has_single_output_row(inner_source)
            )
        )
    )


def _join_is_used(scope, join, alias):
    """Return True if any column outside `join`'s own ON clause references source `alias`."""
    # We need to find all columns that reference this join.
    # But columns in the ON clause shouldn't count.
    # Identity (not equality) is used so that an equal column elsewhere still counts as a use.
    on = join.args.get("on")
    on_clause_columns = {id(column) for column in on.find_all(exp.Column)} if on else set()
    return any(id(column) not in on_clause_columns for column in scope.source_columns(alias))


def _is_joined_on_all_unique_outputs(scope, join):
    """Return True if the equality keys of `join` cover every unique output column of `scope`."""
    unique_outputs = _unique_outputs(scope)
    if not unique_outputs:
        return False

    _, join_keys, _ = join_condition(join)
    # Only plain column references prove that a unique output is matched exactly; an expression
    # such as CAST(y.b AS INT) also reports the name "b" but may map distinct values together.
    joined_columns = {key.name for key in join_keys if isinstance(key, exp.Column)}
    remaining_unique_outputs = unique_outputs - joined_columns
    return not remaining_unique_outputs


def _has_unnamed_projection(scope):
    """Return True if `scope` is a SELECT with at least one projection that has no output name."""
    return isinstance(scope.expression, exp.Select) and any(
        not select.alias_or_name for select in scope.selects
    )


def _unique_outputs(scope):
    """
    Determine output columns of `scope` that must have a unique combination per row.

    Returns an empty set when no such combination can be established.
    """
    expression = scope.expression

    # Every projection takes part in the DISTINCT key. An unnamed projection can never be
    # matched by a join key, so the named projections alone would not prove uniqueness.
    if expression.args.get("distinct") and not _has_unnamed_projection(scope):
        return set(expression.named_selects)

    group = expression.args.get("group")
    if group:
        grouped_expressions = set(group.expressions)
        grouped_outputs = set()

        unique_outputs = set()
        for select in scope.selects:
            output = select.unalias()
            if output in grouped_expressions:
                grouped_outputs.add(output)
                unique_outputs.add(select.alias_or_name)

        # All the grouped expressions must be in the output
        if not grouped_expressions.difference(grouped_outputs):
            return unique_outputs
        return set()

    # With at most one row, any combination of the outputs is unique.
    if _has_single_output_row(scope):
        return set(expression.named_selects)

    return set()


def _has_single_output_row(scope):
    """Return True if `scope` is a SELECT that can produce at most one row."""
    expression = scope.expression
    return isinstance(expression, exp.Select) and (
        # Aggregate-only projections collapse the input into one row, but only without a
        # GROUP BY; with one, the query yields a row per group.
        (
            not expression.args.get("group")
            and all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects)
        )
        or _is_limit_1(scope)
        or not expression.args.get("from")
    )


def _is_limit_1(scope):
    """Return True if `scope` has a literal `LIMIT 1`."""
    limit = scope.expression.args.get("limit")
    return bool(limit) and limit.expression.this == "1"


def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equality predicates with one side referencing the joined source (`join.this`) and the
    other side not referencing it are extracted as (source key, join key) pairs. The join
    itself is not modified; its ON clause is copied first.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression | None]:
            Tuple of (source key, join key, remaining predicate). The remaining predicate is
            None when nothing is left after extracting the keys (including when there is no ON).
    """
    name = join.this.alias_or_name
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        # Record the key pair and replace the equality with TRUE so simplify() can drop it.
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # Wrap a lone predicate in an AND so that flatten() yields it as a conjunct.
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        # An OR of ANDs: an equality is only a join key if it occurs in every disjunct.
        # `conditions` keeps every matching EQ node seen so far (not one representative), so
        # that each occurrence gets replaced with TRUE; a key shared by k disjuncts is therefore
        # appended to source_key/join_key k times.
        conditions = None

        for condition in on.flatten():
            parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions:
            extract_condition(condition)

    on = simplify(on)
    remaining_condition = None if on == exp.true() else on
    return source_key, join_key, remaining_condition
def
eliminate_joins(expression):
def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize. It is modified in place.
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    for scope in traverse_scope(expression):
        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
        # It's probably possible to infer this from the outputs of derived tables.
        # But for now, let's just skip this rule.
        if scope.unqualified_columns:
            continue

        # The "joins" arg may be missing or explicitly None; treat both as "no joins".
        joins = scope.expression.args.get("joins") or []

        # Reverse the joins so we can remove chains of unused joins
        # (e.g. a join whose only references live in a later, removable join's ON clause).
        for join in reversed(joins):
            alias = join.this.alias_or_name
            if _should_eliminate_join(scope, join, alias):
                join.pop()
                scope.remove_source(alias)
    return expression
Remove unused joins from an expression.
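This only removes a join when the joined source is not referenced outside the join's ON clause and the join is known not to produce duplicate rows (for example, a LEFT JOIN on all of a derived table's DISTINCT or GROUP BY columns).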
Example:
>>> import sqlglot
>>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_joins(expression).sql()
'SELECT x.a FROM x'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
join_condition(join):
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equality predicates with one side referencing the joined source (`join.this`) and the
    other side not referencing it are extracted as (source key, join key) pairs. The join
    itself is not modified; its ON clause is copied first.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression | None]:
            Tuple of (source key, join key, remaining predicate). The remaining predicate is
            None when nothing is left after extracting the keys (including when there is no ON).
    """
    name = join.this.alias_or_name
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        # Record the key pair and replace the equality with TRUE so simplify() can drop it.
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # Wrap a lone predicate in an AND so that flatten() yields it as a conjunct.
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        # An OR of ANDs: an equality is only a join key if it occurs in every disjunct.
        # `conditions` keeps every matching EQ node seen so far (not one representative), so
        # that each occurrence gets replaced with TRUE; a key shared by k disjuncts is therefore
        # appended to source_key/join_key k times.
        conditions = None

        for condition in on.flatten():
            parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions:
            extract_condition(condition)

    on = simplify(on)
    remaining_condition = None if on == exp.true() else on
    return source_key, join_key, remaining_condition
Extract the join condition from a join expression.
Arguments:
- join (exp.Join)
Returns:
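tuple[list[exp.Expression], list[exp.Expression], exp.Expression | None]: Tuple of (source key, join key, remaining predicate). The remaining predicate is None when nothing is left after the join keys are extracted (including when the join has no ON clause).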