summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/pushdown_predicates.py
blob: e757322e3eb4982ba11bb98838878d11a24f3535 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.simplify import simplify


def pushdown_predicates(expression):
    """
    Move filter predicates closer to the sources they constrain.

    Every scope of the expression is visited. The WHERE condition of a scope may
    be pushed into any of the scope's selected sources, while the ON condition of
    a join may only be pushed into the source of that very join.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: the same expression object, rewritten in place
    """
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where_clause = query.args.get("where")
        if where_clause:
            pushdown(where_clause.this, scope.selected_sources)

        joins = query.args.get("joins")
        if not joins:
            continue

        for join in joins:
            # An ON condition must never travel into a different join, so the
            # candidate sources are narrowed down to the join's own source.
            alias = join.this.alias_or_name
            on_condition = join.args.get("on")
            pushdown(on_condition, {alias: scope.selected_sources[alias]})

    return expression


def pushdown(condition, sources):
    """
    Simplify a condition, split it into predicates and push those down.

    The condition is swapped in its tree for its simplified form. It is then
    handled as CNF-like when `normalized` accepts it, or when it is not in DNF
    either; only a condition that is in DNF but not CNF goes through the DNF path.

    Args:
        condition (sqlglot.Expression): condition to push down; a falsy value is a no-op.
        sources (dict): candidate sources, handed on to the CNF / DNF handler.
    """
    if not condition:
        return

    simplified = simplify(condition)
    condition = condition.replace(simplified)

    if normalized(condition) or not normalized(condition, dnf=True):
        connector, push = exp.And, pushdown_cnf
    else:
        connector, push = exp.Or, pushdown_dnf

    # A condition whose root is the connector is flattened into its operands;
    # anything else is treated as a single predicate.
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    push(predicates, sources)


def pushdown_cnf(predicates, scope):
    """
    Push down the predicates of a CNF-like condition.

    A predicate that finds a target node is replaced by TRUE at its original
    position and then attached to the target: to the ON clause of a join, or to
    the WHERE clause of a select after its columns have been mapped back to the
    select's projections.

    Args:
        predicates (list): the predicates of the condition.
        scope (dict): candidate sources, forwarded to `nodes_for_predicate`.
    """
    for predicate in predicates:
        targets = nodes_for_predicate(predicate, scope)

        for target in targets.values():
            into_join = isinstance(target, exp.Join)
            if not into_join and not isinstance(target, exp.Select):
                continue

            predicate.replace(exp.TRUE)

            if into_join:
                target.on(predicate, copy=False)
                # once a join has taken the predicate, stop looking at other targets
                break

            target.where(replace_aliases(target, predicate), copy=False)


def pushdown_dnf(predicates, scope):
    """
    Push down conditions taken from a DNF condition (a list of OR-ed blocks).

    Only tables referenced by every block are candidates. For each candidate, the
    conditions touching that table are collected per block and AND-ed together,
    the per-block results are OR-ed, and the combined condition is attached to the
    target node. The original predicates are never replaced, so they stay in
    their original position.

    Args:
        predicates (list): the blocks of the DNF condition.
        scope (dict): candidate sources, forwarded unchanged to `nodes_for_predicate`.
    """
    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    pushdown_tables = set()

    for a in predicates:
        a_tables = set(exp.column_table_names(a))

        # intersecting with every block leaves only the tables common to all of them
        for b in predicates:
            a_tables &= set(exp.column_table_names(b))

        pushdown_tables.update(a_tables)

    # table name -> condition accumulated so far for that table
    conditions = {}

    # for every pushdown table, find all related conditions in all predicates
    # and combine them with ORs
    # (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
    # tables are visited in sorted order so the rewrite is deterministic
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope)

            # this block offers no target node for the table, so it contributes nothing
            if table not in nodes:
                continue

            predicate_condition = None

            # AND together the conditions of this block that involve the table
            for column in predicate.find_all(exp.Column):
                if column.table == table:
                    # the exp.Condition ancestor found for the column
                    # (presumably the closest enclosing one -- confirm)
                    condition = column.find_ancestor(exp.Condition)
                    predicate_condition = (
                        exp.and_(predicate_condition, condition)
                        if predicate_condition
                        else condition
                    )

            # OR this block's contribution into what earlier blocks contributed
            if predicate_condition:
                conditions[table] = (
                    exp.or_(conditions[table], predicate_condition)
                    if table in conditions
                    else predicate_condition
                )

        # NOTE(review): `nodes` is the value bound by the last iteration of the
        # loop above, i.e. the nodes of the final predicate only -- confirm this
        # is intended. `conditions` also still holds entries from earlier `table`
        # iterations, so a name found here may get its condition attached again.
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                # map the select's output names back to its projected expressions
                node.where(replace_aliases(node, predicate), copy=False)


def nodes_for_predicate(predicate, sources):
    """
    Find, per referenced table, the node a predicate could be attached to.

    Args:
        predicate (sqlglot.Expression): predicate to look up targets for.
        sources (dict): maps a table name to a ``(node, source)`` pair.
    Returns:
        dict: table name -> target node (a join or a select). The result is
        empty as soon as one of the targets is a join with a side.
    """
    table_names = exp.column_table_names(predicate)
    enclosing = predicate.find_ancestor(exp.Join, exp.Where)
    in_where = isinstance(enclosing, exp.Where)

    targets = {}

    for name in table_names:
        entry = sources.get(name)
        node, source = entry if entry else (None, None)

        # A predicate living in a WHERE clause may be pushed down, so climb to
        # the join or from statement the node sits in.
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # A FROM over something other than a plain table (a CTE for instance)
        # is resolved to the expression behind that source.
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side:
                return {}
            targets[name] = node
            continue

        # Selects are only targeted by single-table predicates, and only when
        # the select has no GROUP BY.
        if (
            isinstance(node, exp.Select)
            and len(table_names) == 1
            and not node.args.get("group")
        ):
            targets[name] = node

    return targets


def replace_aliases(source, predicate):
    """
    Rewrite a predicate in terms of the expressions projected by ``source``.

    Every column in ``predicate`` whose name matches an output name of ``source``
    (the alias of an aliased projection, or the name of an unaliased one) is
    swapped for a copy of the projected expression.

    Args:
        source (sqlglot.Expression): select whose projections (``source.selects``)
            define the output names.
        predicate (sqlglot.Expression): condition that is being pushed into ``source``.
    Returns:
        sqlglot.Expression: the result of ``predicate.transform`` with the
        substitutions applied.
    """
    aliases = {}

    for select in source.selects:
        if isinstance(select, exp.Alias):
            aliases[select.alias] = select.this
        else:
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            # Hand back a copy: returning the projection node itself would make
            # one node object part of both the select list and the predicate,
            # so later edits to either tree could leak into the other.
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)