summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/eliminate_joins.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/eliminate_joins.py')
-rw-r--r--sqlglot/optimizer/eliminate_joins.py49
1 files changed, 36 insertions, 13 deletions
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index de4e011..3b40710 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -129,10 +129,23 @@ def join_condition(join):
"""
name = join.this.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
- on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
+ def extract_condition(condition):
+ left, right = condition.unnest_operands()
+ left_tables = exp.column_table_names(left)
+ right_tables = exp.column_table_names(right)
+
+ if name in left_tables and name not in right_tables:
+ join_key.append(left)
+ source_key.append(right)
+ condition.replace(exp.true())
+ elif name in right_tables and name not in left_tables:
+ join_key.append(right)
+ source_key.append(left)
+ condition.replace(exp.true())
+
# find the join keys
# SELECT
# FROM x
@@ -141,20 +154,30 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
+ on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
+
for condition in on.flatten():
if isinstance(condition, exp.EQ):
- left, right = condition.unnest_operands()
- left_tables = exp.column_table_names(left)
- right_tables = exp.column_table_names(right)
-
- if name in left_tables and name not in right_tables:
- join_key.append(left)
- source_key.append(right)
- condition.replace(exp.true())
- elif name in right_tables and name not in left_tables:
- join_key.append(right)
- source_key.append(left)
- condition.replace(exp.true())
+ extract_condition(condition)
+ elif normalized(on, dnf=True):
+ conditions = None
+
+ for condition in on.flatten():
+ parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
+ if conditions is None:
+ conditions = parts
+ else:
+ temp = []
+ for p in parts:
+ cs = [c for c in conditions if p == c]
+
+ if cs:
+ temp.append(p)
+ temp.extend(cs)
+ conditions = temp
+
+ for condition in conditions:
+ extract_condition(condition)
on = simplify(on)
remaining_condition = None if on == exp.true() else on