Edit on GitHub

sqlglot.optimizer.optimize_joins

 1from sqlglot import exp
 2from sqlglot.helper import tsort
 3
# Join args inspected by `normalize`: a join with none of these set is marked as a CROSS join.
JOIN_ATTRS = ("on", "side", "kind", "using", "natural")
 5
 6
 7def optimize_joins(expression):
 8    """
 9    Removes cross joins if possible and reorder joins based on predicate dependencies.
10
11    Example:
12        >>> from sqlglot import parse_one
13        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
14        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
15    """
16    for select in expression.find_all(exp.Select):
17        references = {}
18        cross_joins = []
19
20        for join in select.args.get("joins", []):
21            name = join.this.alias_or_name
22            tables = other_table_names(join, name)
23
24            if tables:
25                for table in tables:
26                    references[table] = references.get(table, []) + [join]
27            else:
28                cross_joins.append((name, join))
29
30        for name, join in cross_joins:
31            for dep in references.get(name, []):
32                on = dep.args["on"]
33
34                if isinstance(on, exp.Connector):
35                    for predicate in on.flatten():
36                        if name in exp.column_table_names(predicate):
37                            predicate.replace(exp.true())
38                            join.on(predicate, copy=False)
39
40    expression = reorder_joins(expression)
41    expression = normalize(expression)
42    return expression
43
44
def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    For every FROM clause, a dependency graph is built that maps each joined table
    name to the other table names referenced in its ON clause. The joins of the
    FROM's parent are then replaced with the same joins in `tsort` order.

    Args:
        expression: the expression whose joins are reordered in place.

    Returns:
        The same expression, with reordered joins.
    """
    for from_ in expression.find_all(exp.From):
        head = from_.this
        parent = from_.parent
        head_name = head.alias_or_name
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
        dag = {head_name: []}

        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        # An ON clause may reference names that are not joins of this FROM (for
        # example a table from an enclosing query). Presumably tsort can yield such
        # names as well, so skip them instead of failing with a KeyError.
        parent.set(
            "joins",
            [joins[name] for name in tsort(dag) if name != head_name and name in joins],
        )
    return expression
63
64
def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    A join with none of the JOIN_ATTRS args set is marked as a CROSS join. Any
    join kind other than INNER or OUTER is left untouched.

    Args:
        expression: the expression whose joins are normalized in place.

    Returns:
        The same expression, with normalized joins.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        # Only the INNER and OUTER keywords are redundant; clearing any other kind
        # would silently turn the join into a different type of join.
        if join.kind in ("INNER", "OUTER"):
            join.set("kind", None)
    return expression
76
77
def other_table_names(join, exclude):
    """
    List the table names referenced by the join's ON clause, leaving out `exclude`.

    A join without an ON clause is inspected as if its condition were TRUE.
    """
    condition = join.args.get("on") or exp.true()
    referenced = exp.column_table_names(condition)
    return [table for table in referenced if table != exclude]
def optimize_joins(expression):
    """
    Removes cross joins if possible and reorders joins based on predicate dependencies.

    A join whose ON clause references no table other than its own is treated as a
    cross join. When another join's ON clause is a conjunction (AND) containing
    predicates that mention the cross-joined table, those predicates are moved onto
    the cross join and replaced by TRUE in their original position.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression to optimize; its joins are modified in place.

    Returns:
        The expression after predicate relocation, join reordering and normalization.
    """
    for select in expression.find_all(exp.Select):
        # Maps a table name to the joins whose ON clause mentions that table.
        references = {}
        cross_joins = []

        for join in select.args.get("joins", []):
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            else:
                cross_joins.append((name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only a conjunction can be split safely: pulling a predicate out of
                # an OR and leaving TRUE behind would change which rows are produced.
                if not isinstance(on, exp.And):
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression

Removes cross joins if possible and reorders joins based on predicate dependencies.

Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    For every FROM clause, a dependency graph is built that maps each joined table
    name to the other table names referenced in its ON clause. The joins of the
    FROM's parent are then replaced with the same joins in `tsort` order.

    Args:
        expression: the expression whose joins are reordered in place.

    Returns:
        The same expression, with reordered joins.
    """
    for from_ in expression.find_all(exp.From):
        head = from_.this
        parent = from_.parent
        head_name = head.alias_or_name
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
        dag = {head_name: []}

        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        # An ON clause may reference names that are not joins of this FROM (for
        # example a table from an enclosing query). Presumably tsort can yield such
        # names as well, so skip them instead of failing with a KeyError.
        parent.set(
            "joins",
            [joins[name] for name in tsort(dag) if name != head_name and name in joins],
        )
    return expression

Reorder joins by topological sort order based on predicate references.

def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    A join with none of the JOIN_ATTRS args set is marked as a CROSS join. Any
    join kind other than INNER or OUTER is left untouched.

    Args:
        expression: the expression whose joins are normalized in place.

    Returns:
        The same expression, with normalized joins.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        # Only the INNER and OUTER keywords are redundant; clearing any other kind
        # would silently turn the join into a different type of join.
        if join.kind in ("INNER", "OUTER"):
            join.set("kind", None)
    return expression

Remove INNER and OUTER from joins as they are optional.

def other_table_names(join, exclude):
    """
    List the table names referenced by the join's ON clause, leaving out `exclude`.

    A join without an ON clause is inspected as if its condition were TRUE.
    """
    condition = join.args.get("on") or exp.true()
    referenced = exp.column_table_names(condition)
    return [table for table in referenced if table != exclude]