Edit on GitHub

sqlglot.optimizer.optimize_joins

 1from __future__ import annotations
 2
 3import typing as t
 4
 5from sqlglot import exp
 6from sqlglot.helper import tsort
 7
 8JOIN_ATTRS = ("on", "side", "kind", "using", "method")
 9
10
def optimize_joins(expression):
    """
    Removes cross joins if possible and reorders joins based on predicate dependencies.

    For every SELECT, each AND-connected predicate of a join's ON condition that
    references a table introduced by a cross join is moved onto that cross join.
    The predicate's original position is replaced with TRUE. The joins are then
    topologically reordered (``reorder_joins``) and normalized (``normalize``).

    SELECTs with any join other than a plain inner/cross join are left untouched.
    The same applies to predicates inside an OR. Moving conditions across outer
    joins or out of a disjunction would change the query result.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression tree to optimize; it is modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        # The "joins" arg may be absent or explicitly None.
        joins = select.args.get("joins") or []

        # Predicates may only be shuffled between plain inner/cross joins. Sided
        # (outer) joins, other kinds (e.g. SEMI/ANTI where a dialect has them) and
        # joins with a method are order/condition sensitive, so leave them as written.
        if not all(
            not join.args.get("side")
            and not join.args.get("method")
            and join.kind in ("", "INNER", "CROSS")
            for join in joins
        ):
            continue

        references = {}
        cross_joins = []

        for join in joins:
            tables = other_table_names(join)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            elif not join.args.get("using"):
                # A USING join has no ON but is not a cross join; attaching an ON to
                # it would yield a join with both USING and ON.
                cross_joins.append((join.alias_or_name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only conjuncts can be moved: replacing one operand of an OR with
                # TRUE would make the entire condition TRUE.
                if isinstance(on, exp.And):
                    # Keep the predicate in place if dep depends on a single table only.
                    if len(other_table_names(dep)) < 2:
                        continue

                    for predicate in on.flatten():
                        if name in exp.column_table_names(predicate):
                            predicate.replace(exp.true())
                            join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression
50
51
def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    For each FROM clause, the sibling joins of its parent are sorted with
    ``tsort`` using the tables each join's ON condition references. The parent's
    joins are left in their written order in these cases:

    - any join is not a plain inner/cross join (outer joins are order sensitive);
    - two joins, or a join and the FROM table, share the same name (keying by
      name would otherwise silently drop a join from the query);
    - the references form a cycle, which ``tsort`` rejects.

    Args:
        expression: the expression tree to reorder; it is modified in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        joins = parent.args.get("joins") or []

        if not joins:
            continue

        if not all(
            not join.args.get("side")
            and not join.args.get("method")
            and join.kind in ("", "INNER", "CROSS")
            for join in joins
        ):
            continue

        from_name = from_.alias_or_name
        joins_by_name = {join.alias_or_name: join for join in joins}

        # Duplicate names (e.g. an unaliased self-join) would collapse in the dict
        # and the missing join would vanish from the rewritten query.
        if len(joins_by_name) != len(joins) or from_name in joins_by_name:
            continue

        dag = {name: other_table_names(join) for name, join in joins_by_name.items()}

        try:
            order = tsort(dag)
        except ValueError:
            # Cyclic ON references: there is no valid order, keep the written one.
            continue

        # tsort also yields referenced tables that are not joins (e.g. the FROM
        # table), so keep only names that map to a join.
        parent.set(
            "joins",
            [joins_by_name[name] for name in order if name != from_name and name in joins_by_name],
        )
    return expression
65
66
def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    For every join in the tree:

    - a join with none of ``JOIN_ATTRS`` set becomes an explicit CROSS join;
    - a CROSS join has its ON condition removed;
    - any other join drops an INNER or OUTER kind, and gets ``ON TRUE`` if it has
      neither an ON nor a USING clause. Other kinds (e.g. SEMI/ANTI) change the
      join's meaning and are kept.

    Args:
        expression: the expression tree to normalize; it is modified in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind == "CROSS":
            join.set("on", None)
        else:
            # Only INNER and OUTER are optional keywords; clearing every kind would
            # turn e.g. a SEMI or ANTI join into an ordinary join.
            if join.kind in ("INNER", "OUTER"):
                join.set("kind", None)

            if not join.args.get("on") and not join.args.get("using"):
                join.set("on", exp.true())
    return expression
83
84
def other_table_names(join: exp.Join) -> t.Set[str]:
    """
    Collect the table names referenced by the columns in ``join``'s ON condition.

    The join's own ``alias_or_name`` is passed to ``exp.column_table_names``
    (presumably as a name to exclude, hence "other"). A join without an ON
    condition yields an empty set.
    """
    condition = join.args.get("on")
    if not condition:
        return set()
    return exp.column_table_names(condition, join.alias_or_name)
# Join arguments checked by ``normalize``: a join carrying none of them is
# turned into an explicit CROSS join.
JOIN_ATTRS = (
    "on",
    "side",
    "kind",
    "using",
    "method",
)
def optimize_joins(expression):
    """
    Removes cross joins if possible and reorders joins based on predicate dependencies.

    For every SELECT, each AND-connected predicate of a join's ON condition that
    references a table introduced by a cross join is moved onto that cross join.
    The predicate's original position is replaced with TRUE. The joins are then
    topologically reordered (``reorder_joins``) and normalized (``normalize``).

    SELECTs with any join other than a plain inner/cross join are left untouched.
    The same applies to predicates inside an OR. Moving conditions across outer
    joins or out of a disjunction would change the query result.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression tree to optimize; it is modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        # The "joins" arg may be absent or explicitly None.
        joins = select.args.get("joins") or []

        # Predicates may only be shuffled between plain inner/cross joins. Sided
        # (outer) joins, other kinds (e.g. SEMI/ANTI where a dialect has them) and
        # joins with a method are order/condition sensitive, so leave them as written.
        if not all(
            not join.args.get("side")
            and not join.args.get("method")
            and join.kind in ("", "INNER", "CROSS")
            for join in joins
        ):
            continue

        references = {}
        cross_joins = []

        for join in joins:
            tables = other_table_names(join)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            elif not join.args.get("using"):
                # A USING join has no ON but is not a cross join; attaching an ON to
                # it would yield a join with both USING and ON.
                cross_joins.append((join.alias_or_name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only conjuncts can be moved: replacing one operand of an OR with
                # TRUE would make the entire condition TRUE.
                if isinstance(on, exp.And):
                    # Keep the predicate in place if dep depends on a single table only.
                    if len(other_table_names(dep)) < 2:
                        continue

                    for predicate in on.flatten():
                        if name in exp.column_table_names(predicate):
                            predicate.replace(exp.true())
                            join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression

Removes cross joins where possible and reorders joins based on predicate dependencies.

Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    For each FROM clause, the sibling joins of its parent are sorted with
    ``tsort`` using the tables each join's ON condition references. The parent's
    joins are left in their written order in these cases:

    - any join is not a plain inner/cross join (outer joins are order sensitive);
    - two joins, or a join and the FROM table, share the same name (keying by
      name would otherwise silently drop a join from the query);
    - the references form a cycle, which ``tsort`` rejects.

    Args:
        expression: the expression tree to reorder; it is modified in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        joins = parent.args.get("joins") or []

        if not joins:
            continue

        if not all(
            not join.args.get("side")
            and not join.args.get("method")
            and join.kind in ("", "INNER", "CROSS")
            for join in joins
        ):
            continue

        from_name = from_.alias_or_name
        joins_by_name = {join.alias_or_name: join for join in joins}

        # Duplicate names (e.g. an unaliased self-join) would collapse in the dict
        # and the missing join would vanish from the rewritten query.
        if len(joins_by_name) != len(joins) or from_name in joins_by_name:
            continue

        dag = {name: other_table_names(join) for name, join in joins_by_name.items()}

        try:
            order = tsort(dag)
        except ValueError:
            # Cyclic ON references: there is no valid order, keep the written one.
            continue

        # tsort also yields referenced tables that are not joins (e.g. the FROM
        # table), so keep only names that map to a join.
        parent.set(
            "joins",
            [joins_by_name[name] for name in order if name != from_name and name in joins_by_name],
        )
    return expression

Reorder joins by topological sort order based on predicate references.

def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    For every join in the tree:

    - a join with none of ``JOIN_ATTRS`` set becomes an explicit CROSS join;
    - a CROSS join has its ON condition removed;
    - any other join drops an INNER or OUTER kind, and gets ``ON TRUE`` if it has
      neither an ON nor a USING clause. Other kinds (e.g. SEMI/ANTI) change the
      join's meaning and are kept.

    Args:
        expression: the expression tree to normalize; it is modified in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind == "CROSS":
            join.set("on", None)
        else:
            # Only INNER and OUTER are optional keywords; clearing every kind would
            # turn e.g. a SEMI or ANTI join into an ordinary join.
            if join.kind in ("INNER", "OUTER"):
                join.set("kind", None)

            if not join.args.get("on") and not join.args.get("using"):
                join.set("on", exp.true())
    return expression

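Normalize joins: joins with no ON, side, kind, USING or method become explicit CROSS joins; CROSS joins lose their ON clause; the optional INNER and OUTER keywords are removed from other joins, which receive ON TRUE when they have neither an ON nor a USING clause.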

def other_table_names(join: exp.Join) -> t.Set[str]:
    """
    Collect the table names referenced by the columns in ``join``'s ON condition.

    The join's own ``alias_or_name`` is passed to ``exp.column_table_names``
    (presumably as a name to exclude, hence "other"). A join without an ON
    condition yields an empty set.
    """
    condition = join.args.get("on")
    if not condition:
        return set()
    return exp.column_table_names(condition, join.alias_or_name)