diff options
Diffstat (limited to 'sqlglot/optimizer/eliminate_joins.py')
-rw-r--r-- | sqlglot/optimizer/eliminate_joins.py | 49 |
1 files changed, 36 insertions, 13 deletions
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index de4e011..3b40710 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -129,10 +129,23 @@ def join_condition(join): """ name = join.this.alias_or_name on = (join.args.get("on") or exp.true()).copy() - on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) source_key = [] join_key = [] + def extract_condition(condition): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.true()) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.true()) + # find the join keys # SELECT # FROM x @@ -141,20 +154,30 @@ def join_condition(join): # # should pull y.b as the join key and x.a as the source key if normalized(on): + on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) + for condition in on.flatten(): if isinstance(condition, exp.EQ): - left, right = condition.unnest_operands() - left_tables = exp.column_table_names(left) - right_tables = exp.column_table_names(right) - - if name in left_tables and name not in right_tables: - join_key.append(left) - source_key.append(right) - condition.replace(exp.true()) - elif name in right_tables and name not in left_tables: - join_key.append(right) - source_key.append(left) - condition.replace(exp.true()) + extract_condition(condition) + elif normalized(on, dnf=True): + conditions = None + + for condition in on.flatten(): + parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)] + if conditions is None: + conditions = parts + else: + temp = [] + for p in parts: + cs = [c for c in conditions if p == c] + + if cs: + temp.append(p) + temp.extend(cs) + conditions = temp + + for condition in conditions: + extract_condition(condition) on = simplify(on) remaining_condition = None if on == exp.true() else on |