summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/optimize_joins.py
blob: 0c74e36b21289eaa3700c21da274c3cb5e7fe76d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from sqlglot import exp
from sqlglot.helper import tsort
from sqlglot.optimizer.simplify import simplify


def optimize_joins(expression):
    """
    Removes cross joins if possible and reorder joins based on predicate dependencies.

    For each SELECT, a join whose ON condition references no other table (a
    cross join) absorbs the conjuncts of other inner joins' ON conditions that
    mention it. The joins are then reordered so that each ON condition only
    refers to tables joined before it, and the optional INNER/OUTER keywords
    are dropped.

    Args:
        expression: the expression tree to optimize. It is modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        references = {}
        cross_joins = []

        # `joins` may be missing or explicitly set to None.
        for join in select.args.get("joins") or []:
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if tables:
                # Index this join under every other table its ON condition mentions.
                for table in tables:
                    references.setdefault(table, []).append(join)
            elif not join.args.get("side") and not join.args.get("using"):
                # Only plain joins (no USING, no outer side) can safely be
                # given an ON condition.
                cross_joins.append((name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]
                on = on.replace(simplify(on))

                # Pulling predicates out of an outer join's ON clause would
                # change which rows that join preserves.
                if dep.args.get("side"):
                    continue

                # Only conjuncts can be moved: lifting one operand out of an OR
                # would change the meaning of the condition.
                if not isinstance(on, exp.And):
                    continue

                # Materialize the operands before mutating the tree they live in.
                for predicate in list(on.flatten()):
                    if name in exp.column_table_names(predicate):
                        # Use a fresh TRUE node so the shared exp.TRUE object is
                        # never attached to (and re-parented inside) a query tree.
                        predicate.replace(exp.TRUE.copy())
                        join.on(predicate, copy=False)

                        # The join now has an ON condition; most dialects reject
                        # "CROSS JOIN ... ON ...", so drop the CROSS keyword.
                        if join.kind == "CROSS":
                            join.set("kind", None)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression


def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    Each join is placed after every table its ON condition references. A
    SELECT whose FROM table and joins do not all have distinct names is left
    untouched, because keying joins by name would otherwise drop some of them.

    Args:
        expression: the expression tree to reorder. It is modified in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        head_name = from_.expressions[0].alias_or_name
        parent = from_.parent
        join_list = parent.args.get("joins") or []
        joins = {join.this.alias_or_name: join for join in join_list}

        # Duplicate names (or a join named like the FROM table) would collapse
        # in the mapping above and silently remove joins from the query.
        if len(joins) != len(join_list) or head_name in joins:
            continue

        dag = {head_name: []}
        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        # Names that are only referenced, not joined here (e.g. additional
        # FROM tables), may appear in the sort result; keep real joins only.
        parent.set(
            "joins",
            [joins[name] for name in tsort(dag) if name in joins],
        )
    return expression


def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Only those two keywords are stripped. CROSS and any other join kind are
    kept, so no kind that carries meaning is silently discarded.

    Args:
        expression: the expression tree to normalize. It is modified in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if join.kind in ("INNER", "OUTER"):
            join.set("kind", None)
    return expression


def other_table_names(join, exclude):
    """
    List the tables referenced by columns in ``join``'s ON condition.

    Args:
        join: the join expression to inspect.
        exclude: a table name to leave out of the result (callers in this
            module pass the join's own name).

    Returns:
        The referenced table names other than ``exclude``; an empty list when
        the join has no ON condition.
    """
    condition = join.args.get("on")
    if not condition:
        return []

    referenced = exp.column_table_names(condition)
    return [table for table in referenced if table != exclude]