summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/optimize_joins.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/optimize_joins.py')
-rw-r--r--sqlglot/optimizer/optimize_joins.py33
1 files changed, 16 insertions, 17 deletions
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 4e0c3a1..d51276f 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot import exp
from sqlglot.helper import tsort
@@ -13,25 +17,28 @@ def optimize_joins(expression):
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
"""
+
for select in expression.find_all(exp.Select):
references = {}
cross_joins = []
for join in select.args.get("joins", []):
- name = join.this.alias_or_name
- tables = other_table_names(join, name)
+ tables = other_table_names(join)
if tables:
for table in tables:
references[table] = references.get(table, []) + [join]
else:
- cross_joins.append((name, join))
+ cross_joins.append((join.alias_or_name, join))
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
if isinstance(on, exp.Connector):
+ if len(other_table_names(dep)) < 2:
+ continue
+
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.true())
@@ -47,17 +54,12 @@ def reorder_joins(expression):
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
- head = from_.this
parent = from_.parent
- joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
- dag = {head.alias_or_name: []}
-
- for name, join in joins.items():
- dag[name] = other_table_names(join, name)
-
+ joins = {join.alias_or_name: join for join in parent.args.get("joins", [])}
+ dag = {name: other_table_names(join) for name, join in joins.items()}
parent.set(
"joins",
- [joins[name] for name in tsort(dag) if name != head.alias_or_name],
+ [joins[name] for name in tsort(dag) if name != from_.alias_or_name],
)
return expression
@@ -75,9 +77,6 @@ def normalize(expression):
return expression
-def other_table_names(join, exclude):
- return [
- name
- for name in (exp.column_table_names(join.args.get("on") or exp.true()))
- if name != exclude
- ]
+def other_table_names(join: exp.Join) -> t.Set[str]:
+ on = join.args.get("on")
+ return exp.column_table_names(on, join.alias_or_name) if on else set()