Edit on GitHub

sqlglot.optimizer.pushdown_predicates

  1from sqlglot import exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import build_scope
  4from sqlglot.optimizer.simplify import simplify
  5
  6
  7def pushdown_predicates(expression):
  8    """
  9    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
 10
 11    Example:
 12        >>> import sqlglot
 13        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
 14        >>> expression = sqlglot.parse_one(sql)
 15        >>> pushdown_predicates(expression).sql()
 16        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
 17
 18    Args:
 19        expression (sqlglot.Expression): expression to optimize
 20    Returns:
 21        sqlglot.Expression: optimized expression
 22    """
 23    root = build_scope(expression)
 24
 25    if root:
 26        scope_ref_count = root.ref_count()
 27
 28        for scope in reversed(list(root.traverse())):
 29            select = scope.expression
 30            where = select.args.get("where")
 31            if where:
 32                selected_sources = scope.selected_sources
 33                # a right join can only push down to itself and not the source FROM table
 34                for k, (node, source) in selected_sources.items():
 35                    parent = node.find_ancestor(exp.Join, exp.From)
 36                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
 37                        selected_sources = {k: (node, source)}
 38                        break
 39                pushdown(where.this, selected_sources, scope_ref_count)
 40
 41            # joins should only pushdown into itself, not to other joins
 42            # so we limit the selected sources to only itself
 43            for join in select.args.get("joins") or []:
 44                name = join.alias_or_name
 45                if name in scope.selected_sources:
 46                    pushdown(
 47                        join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
 48                    )
 49
 50    return expression
 51
 52
 53def pushdown(condition, sources, scope_ref_count):
 54    if not condition:
 55        return
 56
 57    condition = condition.replace(simplify(condition))
 58    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 59
 60    predicates = list(
 61        condition.flatten()
 62        if isinstance(condition, exp.And if cnf_like else exp.Or)
 63        else [condition]
 64    )
 65
 66    if cnf_like:
 67        pushdown_cnf(predicates, sources, scope_ref_count)
 68    else:
 69        pushdown_dnf(predicates, sources, scope_ref_count)
 70
 71
 72def pushdown_cnf(predicates, scope, scope_ref_count):
 73    """
 74    If the predicates are in CNF like form, we can simply replace each block in the parent.
 75    """
 76    for predicate in predicates:
 77        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
 78            if isinstance(node, exp.Join):
 79                predicate.replace(exp.true())
 80                node.on(predicate, copy=False)
 81                break
 82            if isinstance(node, exp.Select):
 83                predicate.replace(exp.true())
 84                node.where(replace_aliases(node, predicate), copy=False)
 85
 86
 87def pushdown_dnf(predicates, scope, scope_ref_count):
 88    """
 89    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
 90    Additionally, we can't remove predicates from their original form.
 91    """
 92    # find all the tables that can be pushdown too
 93    # these are tables that are referenced in all blocks of a DNF
 94    # (a.x AND b.x) OR (a.y AND c.y)
 95    # only table a can be push down
 96    pushdown_tables = set()
 97
 98    for a in predicates:
 99        a_tables = exp.column_table_names(a)
100
101        for b in predicates:
102            a_tables &= exp.column_table_names(b)
103
104        pushdown_tables.update(a_tables)
105
106    conditions = {}
107
108    # for every pushdown table, find all related conditions in all predicates
109    # combine them with ORS
110    # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
111    for table in sorted(pushdown_tables):
112        for predicate in predicates:
113            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
114
115            if table not in nodes:
116                continue
117
118            predicate_condition = None
119
120            for column in predicate.find_all(exp.Column):
121                if column.table == table:
122                    condition = column.find_ancestor(exp.Condition)
123                    predicate_condition = (
124                        exp.and_(predicate_condition, condition)
125                        if predicate_condition
126                        else condition
127                    )
128
129            if predicate_condition:
130                conditions[table] = (
131                    exp.or_(conditions[table], predicate_condition)
132                    if table in conditions
133                    else predicate_condition
134                )
135
136        for name, node in nodes.items():
137            if name not in conditions:
138                continue
139
140            predicate = conditions[name]
141
142            if isinstance(node, exp.Join):
143                node.on(predicate, copy=False)
144            elif isinstance(node, exp.Select):
145                node.where(replace_aliases(node, predicate), copy=False)
146
147
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes that `predicate` can be pushed down into.

    Args:
        predicate: a single conjunct (CNF) or block (DNF) of a condition.
        sources: mapping of source name -> (node, source), where ``source`` is
            either an ``exp.Table`` or a scope.
        scope_ref_count: reference counts keyed by ``id()`` of a source scope.

    Returns:
        dict mapping table name -> target node. The target is either an
        ``exp.Join`` (whose ON clause can take the predicate) or an
        ``exp.Select`` (whose WHERE clause can take it). An empty dict means
        the predicate must not be pushed at all.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE (or derived table) which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            # never push into a recursive CTE
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # moving a WHERE filter into the ON clause of a LEFT/FULL/... join is
            # not equivalent, so the predicate has to stay where it is
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            projections = node.selects
            # We can't push down into window expressions or aggregates: both are
            # computed over the rows that pass the WHERE clause, so filtering
            # earlier changes their values (and an aggregate can't go in WHERE).
            has_window_expression = any(p.find(exp.Window) is not None for p in projections)
            has_aggregate = any(p.find(exp.AggFunc) is not None for p in projections)
            # LIMIT / OFFSET pick rows before the outer filter would apply.
            has_row_limit = bool(node.args.get("limit") or node.args.get("offset"))
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
                and not has_aggregate
                and not has_row_limit
            ):
                nodes[table] = node
    return nodes
186
187
def replace_aliases(source, predicate):
    """
    Rewrite `predicate` in terms of the projection expressions of `source`.

    Every column in `predicate` whose name matches a projection of the SELECT
    `source` is replaced by a copy of that projection's expression:
    - for ``expr AS name``, the aliased ``expr``;
    - otherwise, the projection itself.

    Columns are matched by name only; their table qualifier is ignored.

    Args:
        source: the SELECT whose projections define the replacements.
        predicate: the condition to rewrite.

    Returns:
        The result of ``predicate.transform`` with the replacements applied
        (presumably a copy, since transform copies by default; confirm for this
        sqlglot version).
    """

    def name_and_expression(projection):
        if isinstance(projection, exp.Alias):
            return projection.alias, projection.this
        return projection.name, projection

    replacements = dict(name_and_expression(projection) for projection in source.selects)

    def substitute(node):
        if isinstance(node, exp.Column) and node.name in replacements:
            return replacements[node.name].copy()
        return node

    return predicate.transform(substitute)
def pushdown_predicates(expression):
    """
    Rewrite the sqlglot AST to push predicates down into FROMs and JOINs.

    Scopes are visited in reverse traversal order. A scope's WHERE condition is
    pushed towards the sources it selects from, and each JOIN's ON condition is
    pushed only into that joined source itself.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    root = build_scope(expression)
    if not root:
        return expression

    scope_ref_count = root.ref_count()

    for scope in reversed(list(root.traverse())):
        select = scope.expression
        where = select.args.get("where")

        if where:
            candidates = scope.selected_sources
            # A RIGHT JOIN may only receive predicates for itself, never for the
            # FROM table, so narrow the candidates to the first right-joined source.
            for source_name, pair in candidates.items():
                enclosing = pair[0].find_ancestor(exp.Join, exp.From)
                if isinstance(enclosing, exp.Join) and enclosing.side == "RIGHT":
                    candidates = {source_name: pair}
                    break
            pushdown(where.this, candidates, scope_ref_count)

        # An ON condition may only be pushed into its own join, never into other
        # joins, so each call only sees the source of that join.
        for join in select.args.get("joins") or []:
            join_name = join.alias_or_name
            if join_name not in scope.selected_sources:
                continue
            own_source = {join_name: scope.selected_sources[join_name]}
            pushdown(join.args.get("on"), own_source, scope_ref_count)

    return expression

Rewrite the sqlglot AST to push predicates down into FROMs and JOINs.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:
  • sqlglot.Expression: optimized expression

def pushdown(condition, sources, scope_ref_count):
    """
    Simplify `condition` in place, then push its parts down towards `sources`.

    A condition that is in CNF, or that is not in DNF, is treated as CNF-like
    and split into its AND-ed conjuncts. Otherwise it is split into its OR-ed
    blocks. A missing (falsy) condition is ignored.

    Args:
        condition: the WHERE or ON condition to push down, or None.
        sources: mapping of source name -> (node, source) candidates.
        scope_ref_count: reference counts keyed by ``id()`` of a source scope.
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition))

    as_cnf = normalized(condition) or not normalized(condition, dnf=True)
    connector = exp.And if as_cnf else exp.Or

    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    strategy = pushdown_cnf if as_cnf else pushdown_dnf
    strategy(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Each conjunct is handled on its own. It is replaced with TRUE where it
    currently sits and attached either to a JOIN's ON clause or to the WHERE
    clause of the derived table / CTE that it exclusively references.

    When a conjunct references several joined tables, it is attached to the
    *last* of those joins, because an ON clause may only reference tables that
    are joined at or before it.

    Args:
        predicates: the AND-ed conjuncts of the condition being pushed down.
        scope: mapping of source name -> (node, source) candidates, as passed
            to ``nodes_for_predicate``.
        scope_ref_count: reference counts keyed by ``id()`` of a source scope.
    """

    def join_position(join):
        # Index of `join` within its SELECT's list of joins (-1 if not found).
        parent = join.parent
        siblings = (parent.args.get("joins") if parent is not None else None) or []
        for position, sibling in enumerate(siblings):
            if sibling is join:
                return position
        return -1

    for predicate in predicates:
        nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
        joins = [node for node in nodes.values() if isinstance(node, exp.Join)]

        if joins:
            # max() keeps the first of equally ranked joins, i.e. sorted-name order
            target = max(joins, key=join_position)
            predicate.replace(exp.true())
            target.on(predicate, copy=False)
            continue

        for node in nodes.values():
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                node.where(replace_aliases(node, predicate), copy=False)

If the predicates are in CNF-like form, each conjunct can be pushed down on its own: it is replaced with TRUE in its original position and attached to the node it is pushed into.

def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    For every table referenced by all blocks:
    - The terms of each block (its AND-ed conjuncts) that reference the table
      are ANDed together.
    - The per-block results are then ORed.

    For example, ``(a.x AND a.y AND b.x) OR (a.z AND c.y)`` pushes
    ``(a.x AND a.y) OR a.z`` to ``a``.

    That disjunction is implied by the original condition only if every block
    contributes to it, so a table is skipped as soon as one block cannot be
    pushed to it. The pushed terms are copies, so the original blocks stay in
    place and keep their parents.

    Args:
        predicates: the OR-ed blocks of the condition being pushed down.
        scope: mapping of source name -> (node, source) candidates, as passed
            to ``nodes_for_predicate``.
        scope_ref_count: reference counts keyed by ``id()`` of a source scope.
    """
    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set(table_sets[0]).intersection(*table_sets[1:]) if table_sets else set()

    for table in sorted(pushdown_tables):
        target = None
        disjunction = None

        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)

            if table not in nodes:
                # Leaving this block out of the OR would make the pushed filter
                # stricter than the original condition and lose rows.
                disjunction = None
                break

            target = nodes[table]

            # AND together (copies of) the terms of this block that mention `table`.
            # Whole terms are used so that e.g. a surrounding NOT is never dropped.
            terms = predicate.flatten() if isinstance(predicate, exp.And) else [predicate]
            block_condition = None
            for term in terms:
                if table not in exp.column_table_names(term):
                    continue
                term = term.copy()
                block_condition = (
                    exp.and_(block_condition, term) if block_condition is not None else term
                )

            if block_condition is None:
                # this block places no constraint on `table`
                disjunction = None
                break

            disjunction = (
                exp.or_(disjunction, block_condition)
                if disjunction is not None
                else block_condition
            )

        if disjunction is None:
            continue

        # push each table's disjunction exactly once, to that table's own node
        if isinstance(target, exp.Join):
            target.on(disjunction, copy=False)
        elif isinstance(target, exp.Select):
            target.where(replace_aliases(target, disjunction), copy=False)

If the predicates are in DNF form, only conditions on tables that are referenced by every block can be pushed down. Additionally, the original predicates cannot be removed from where they are; only an implied condition is pushed.

def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes that `predicate` can be pushed down into.

    Args:
        predicate: a single conjunct (CNF) or block (DNF) of a condition.
        sources: mapping of source name -> (node, source), where ``source`` is
            either an ``exp.Table`` or a scope.
        scope_ref_count: reference counts keyed by ``id()`` of a source scope.

    Returns:
        dict mapping table name -> target node. The target is either an
        ``exp.Join`` (whose ON clause can take the predicate) or an
        ``exp.Select`` (whose WHERE clause can take it). An empty dict means
        the predicate must not be pushed at all.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE (or derived table) which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            # never push into a recursive CTE
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # moving a WHERE filter into the ON clause of a LEFT/FULL/... join is
            # not equivalent, so the predicate has to stay where it is
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            projections = node.selects
            # We can't push down into window expressions or aggregates: both are
            # computed over the rows that pass the WHERE clause, so filtering
            # earlier changes their values (and an aggregate can't go in WHERE).
            has_window_expression = any(p.find(exp.Window) is not None for p in projections)
            has_aggregate = any(p.find(exp.AggFunc) is not None for p in projections)
            # LIMIT / OFFSET pick rows before the outer filter would apply.
            has_row_limit = bool(node.args.get("limit") or node.args.get("offset"))
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
                and not has_aggregate
                and not has_row_limit
            ):
                nodes[table] = node
    return nodes
def replace_aliases(source, predicate):
    """
    Rewrite `predicate` in terms of the projection expressions of `source`.

    Every column in `predicate` whose name matches a projection of the SELECT
    `source` is replaced by a copy of that projection's expression:
    - for ``expr AS name``, the aliased ``expr``;
    - otherwise, the projection itself.

    Columns are matched by name only; their table qualifier is ignored.

    Args:
        source: the SELECT whose projections define the replacements.
        predicate: the condition to rewrite.

    Returns:
        The result of ``predicate.transform`` with the replacements applied
        (presumably a copy, since transform copies by default; confirm for this
        sqlglot version).
    """

    def name_and_expression(projection):
        if isinstance(projection, exp.Alias):
            return projection.alias, projection.this
        return projection.name, projection

    replacements = dict(name_and_expression(projection) for projection in source.selects)

    def substitute(node):
        if isinstance(node, exp.Column) and node.name in replacements:
            return replacements[node.name].copy()
        return node

    return predicate.transform(substitute)