Edit on GitHub

sqlglot.optimizer.optimize_joins

 1from sqlglot import exp
 2from sqlglot.helper import tsort
 3
 4
 5def optimize_joins(expression):
 6    """
 7    Removes cross joins if possible and reorder joins based on predicate dependencies.
 8
 9    Example:
10        >>> from sqlglot import parse_one
11        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
12        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
13    """
14    for select in expression.find_all(exp.Select):
15        references = {}
16        cross_joins = []
17
18        for join in select.args.get("joins", []):
19            name = join.this.alias_or_name
20            tables = other_table_names(join, name)
21
22            if tables:
23                for table in tables:
24                    references[table] = references.get(table, []) + [join]
25            else:
26                cross_joins.append((name, join))
27
28        for name, join in cross_joins:
29            for dep in references.get(name, []):
30                on = dep.args["on"]
31
32                if isinstance(on, exp.Connector):
33                    for predicate in on.flatten():
34                        if name in exp.column_table_names(predicate):
35                            predicate.replace(exp.true())
36                            join.on(predicate, copy=False)
37
38    expression = reorder_joins(expression)
39    expression = normalize(expression)
40    return expression
41
42
def reorder_joins(expression):
    """
    Reorders joins by topological sort order based on predicate references.

    Each join is placed after the joins whose tables its ON clause references.
    Some referenced names are not joins of the same query, e.g. the FROM table or
    a table from an enclosing query. They constrain the sort but are never emitted
    as joins.

    A query is left untouched when its join names are ambiguous, i.e. when one name
    is used by several joins or by both a join and the FROM table. Keying joins by
    name would otherwise silently drop one of them.

    Args:
        expression: the expression tree whose joins are reordered in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        join_nodes = parent.args.get("joins") or []
        if not join_nodes:
            continue

        head_name = from_.expressions[0].alias_or_name
        joins = {join.this.alias_or_name: join for join in join_nodes}

        if len(joins) != len(join_nodes) or head_name in joins:
            continue

        dag = {head_name: []}
        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        # tsort also yields referenced names that are not dag keys, so keep only
        # names that really are joins of this query.
        parent.set("joins", [joins[name] for name in tsort(dag) if name in joins])
    return expression
61
62
def normalize(expression):
    """
    Removes INNER and OUTER from joins as they are optional.

    Only those two kinds are cleared. Any other join kind (CROSS, or whatever
    else the parser may produce) carries meaning and is kept.

    Args:
        expression: the expression tree whose joins are normalized in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if join.kind in ("INNER", "OUTER"):
            join.set("kind", None)
    return expression
71
72
def other_table_names(join, exclude):
    """
    Returns the table names referenced by the join's ON clause, except ``exclude``.

    A join with no ON clause is evaluated as if its condition were TRUE.
    """
    condition = join.args.get("on") or exp.true()
    referenced = exp.column_table_names(condition)
    return list(filter(lambda table: table != exclude, referenced))
def optimize_joins(expression):
    """
    Removes cross joins if possible and reorders joins based on predicate dependencies.

    A join is treated as a cross join here when its ON clause references no other
    table and it has neither a USING clause nor a side (LEFT/RIGHT/FULL).

    Conjuncts of another join's AND-ed ON clause that mention such a table are moved
    into that table's own ON clause, and their old position is replaced with TRUE.
    This only happens when that other join is not an outer join.

    Afterwards the joins are reordered and normalized.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression tree to optimize; its nodes are modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        references = {}
        cross_joins = []

        for join in select.args.get("joins") or []:
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            elif not join.args.get("using") and not join.args.get("side"):
                # Only plain joins may absorb predicates. Adding ON to a USING join
                # produces an invalid join, and moving predicates into an outer
                # join would change which rows it returns.
                cross_joins.append((name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Pulling a conjunct out is only sound for an AND chain of a
                # non-outer join. Removing a term from an OR, or from an outer
                # join's ON clause, changes the query's result.
                if not isinstance(on, exp.And) or dep.args.get("side"):
                    continue

                # Materialize first: predicates are replaced while we walk them.
                for predicate in list(on.flatten()):
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression

Removes cross joins where possible and reorders joins based on predicate dependencies.

Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def reorder_joins(expression):
    """
    Reorders joins by topological sort order based on predicate references.

    Each join is placed after the joins whose tables its ON clause references.
    Some referenced names are not joins of the same query, e.g. the FROM table or
    a table from an enclosing query. They constrain the sort but are never emitted
    as joins.

    A query is left untouched when its join names are ambiguous, i.e. when one name
    is used by several joins or by both a join and the FROM table. Keying joins by
    name would otherwise silently drop one of them.

    Args:
        expression: the expression tree whose joins are reordered in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        join_nodes = parent.args.get("joins") or []
        if not join_nodes:
            continue

        head_name = from_.expressions[0].alias_or_name
        joins = {join.this.alias_or_name: join for join in join_nodes}

        if len(joins) != len(join_nodes) or head_name in joins:
            continue

        dag = {head_name: []}
        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        # tsort also yields referenced names that are not dag keys, so keep only
        # names that really are joins of this query.
        parent.set("joins", [joins[name] for name in tsort(dag) if name in joins])
    return expression

Reorders joins in topological order, based on the tables referenced by each join's predicate.

def normalize(expression):
    """
    Removes INNER and OUTER from joins as they are optional.

    Only those two kinds are cleared. Any other join kind (CROSS, or whatever
    else the parser may produce) carries meaning and is kept.

    Args:
        expression: the expression tree whose joins are normalized in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if join.kind in ("INNER", "OUTER"):
            join.set("kind", None)
    return expression

Removes the INNER and OUTER keywords from joins, since they are optional.

def other_table_names(join, exclude):
    """
    Returns the table names referenced by the join's ON clause, except ``exclude``.

    A join with no ON clause is evaluated as if its condition were TRUE.
    """
    condition = join.args.get("on") or exp.true()
    referenced = exp.column_table_names(condition)
    return list(filter(lambda table: table != exclude, referenced))