diff options
Diffstat (limited to 'sqlglot/optimizer/optimize_joins.py')
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 4e0c3a1..d51276f 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import typing as t + from sqlglot import exp from sqlglot.helper import tsort @@ -13,25 +17,28 @@ def optimize_joins(expression): >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql() 'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a' """ + for select in expression.find_all(exp.Select): references = {} cross_joins = [] for join in select.args.get("joins", []): - name = join.this.alias_or_name - tables = other_table_names(join, name) + tables = other_table_names(join) if tables: for table in tables: references[table] = references.get(table, []) + [join] else: - cross_joins.append((name, join)) + cross_joins.append((join.alias_or_name, join)) for name, join in cross_joins: for dep in references.get(name, []): on = dep.args["on"] if isinstance(on, exp.Connector): + if len(other_table_names(dep)) < 2: + continue + for predicate in on.flatten(): if name in exp.column_table_names(predicate): predicate.replace(exp.true()) @@ -47,17 +54,12 @@ def reorder_joins(expression): Reorder joins by topological sort order based on predicate references. """ for from_ in expression.find_all(exp.From): - head = from_.this parent = from_.parent - joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])} - dag = {head.alias_or_name: []} - - for name, join in joins.items(): - dag[name] = other_table_names(join, name) - + joins = {join.alias_or_name: join for join in parent.args.get("joins", [])} + dag = {name: other_table_names(join) for name, join in joins.items()} parent.set( "joins", - [joins[name] for name in tsort(dag) if name != head.alias_or_name], + [joins[name] for name in tsort(dag) if name != from_.alias_or_name], ) return expression @@ -75,9 +77,6 @@ def normalize(expression): return expression -def other_table_names(join, exclude): - return [ - name - for name in (exp.column_table_names(join.args.get("on") or exp.true())) - if name != exclude - ] +def other_table_names(join: exp.Join) -> t.Set[str]: + on = join.args.get("on") + return exp.column_table_names(on, join.alias_or_name) if on else set() |