Edit on GitHub

sqlglot.optimizer.pushdown_predicates

  1from sqlglot import exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import build_scope, find_in_scope
  4from sqlglot.optimizer.simplify import simplify
  5
  6
  7def pushdown_predicates(expression):
  8    """
  9    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
 10
 11    Example:
 12        >>> import sqlglot
 13        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
 14        >>> expression = sqlglot.parse_one(sql)
 15        >>> pushdown_predicates(expression).sql()
 16        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
 17
 18    Args:
 19        expression (sqlglot.Expression): expression to optimize
 20    Returns:
 21        sqlglot.Expression: optimized expression
 22    """
 23    root = build_scope(expression)
 24
 25    if root:
 26        scope_ref_count = root.ref_count()
 27
 28        for scope in reversed(list(root.traverse())):
 29            select = scope.expression
 30            where = select.args.get("where")
 31            if where:
 32                selected_sources = scope.selected_sources
 33                # a right join can only push down to itself and not the source FROM table
 34                for k, (node, source) in selected_sources.items():
 35                    parent = node.find_ancestor(exp.Join, exp.From)
 36                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
 37                        selected_sources = {k: (node, source)}
 38                        break
 39                pushdown(where.this, selected_sources, scope_ref_count)
 40
 41            # joins should only pushdown into itself, not to other joins
 42            # so we limit the selected sources to only itself
 43            for join in select.args.get("joins") or []:
 44                name = join.alias_or_name
 45                if name in scope.selected_sources:
 46                    pushdown(
 47                        join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
 48                    )
 49
 50    return expression
 51
 52
 53def pushdown(condition, sources, scope_ref_count):
 54    if not condition:
 55        return
 56
 57    condition = condition.replace(simplify(condition))
 58    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 59
 60    predicates = list(
 61        condition.flatten()
 62        if isinstance(condition, exp.And if cnf_like else exp.Or)
 63        else [condition]
 64    )
 65
 66    if cnf_like:
 67        pushdown_cnf(predicates, sources, scope_ref_count)
 68    else:
 69        pushdown_dnf(predicates, sources, scope_ref_count)
 70
 71
 72def pushdown_cnf(predicates, scope, scope_ref_count):
 73    """
 74    If the predicates are in CNF like form, we can simply replace each block in the parent.
 75    """
 76    for predicate in predicates:
 77        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
 78            if isinstance(node, exp.Join):
 79                predicate.replace(exp.true())
 80                node.on(predicate, copy=False)
 81                break
 82            if isinstance(node, exp.Select):
 83                predicate.replace(exp.true())
 84                inner_predicate = replace_aliases(node, predicate)
 85                if find_in_scope(inner_predicate, exp.AggFunc):
 86                    node.having(inner_predicate, copy=False)
 87                else:
 88                    node.where(inner_predicate, copy=False)
 89
 90
 91def pushdown_dnf(predicates, scope, scope_ref_count):
 92    """
 93    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
 94    Additionally, we can't remove predicates from their original form.
 95    """
 96    # find all the tables that can be pushdown too
 97    # these are tables that are referenced in all blocks of a DNF
 98    # (a.x AND b.x) OR (a.y AND c.y)
 99    # only table a can be push down
100    pushdown_tables = set()
101
102    for a in predicates:
103        a_tables = exp.column_table_names(a)
104
105        for b in predicates:
106            a_tables &= exp.column_table_names(b)
107
108        pushdown_tables.update(a_tables)
109
110    conditions = {}
111
112    # for every pushdown table, find all related conditions in all predicates
113    # combine them with ORS
114    # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
115    for table in sorted(pushdown_tables):
116        for predicate in predicates:
117            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
118
119            if table not in nodes:
120                continue
121
122            predicate_condition = None
123
124            for column in predicate.find_all(exp.Column):
125                if column.table == table:
126                    condition = column.find_ancestor(exp.Condition)
127                    predicate_condition = (
128                        exp.and_(predicate_condition, condition)
129                        if predicate_condition
130                        else condition
131                    )
132
133            if predicate_condition:
134                conditions[table] = (
135                    exp.or_(conditions[table], predicate_condition)
136                    if table in conditions
137                    else predicate_condition
138                )
139
140        for name, node in nodes.items():
141            if name not in conditions:
142                continue
143
144            predicate = conditions[name]
145
146            if isinstance(node, exp.Join):
147                node.on(predicate, copy=False)
148            elif isinstance(node, exp.Select):
149                inner_predicate = replace_aliases(node, predicate)
150                if find_in_scope(inner_predicate, exp.AggFunc):
151                    node.having(inner_predicate, copy=False)
152                else:
153                    node.where(inner_predicate, copy=False)
154
155
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes a predicate may be pushed down into.

    Args:
        predicate: a single block of a CNF/DNF condition.
        sources: mapping of source name to a ``(node, source)`` pair.
        scope_ref_count: scope reference counts keyed by ``id(scope)``.

    Returns:
        dict mapping table name to the JOIN or SELECT node that may receive the
        predicate. An empty dict means the predicate must not be pushed anywhere.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a FROM source that is not a plain table (e.g. a CTE or subquery) can receive
        # the predicate in its own query; presumably `source` is a Scope here
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # outer joins other than RIGHT disable pushdown for the whole predicate
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places, and filtering before LIMIT/OFFSET would change which rows
            # the derived table returns.
            if (
                not node.args.get("group")
                and not node.args.get("limit")
                and not node.args.get("offset")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes
194
195
def replace_aliases(source, predicate):
    """
    Return a copy of `predicate` in which each column naming one of `source`'s
    projections is replaced by a copy of that projection's expression.

    Aliased projections map their alias to the aliased expression. Other projections
    map their own name to themselves.
    """
    projections = {}
    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            key, value = projection.alias, projection.this
        else:
            key, value = projection.name, projection
        projections[key] = value

    def substitute(node):
        if isinstance(node, exp.Column) and node.name in projections:
            return projections[node.name].copy()
        return node

    return predicate.transform(substitute)
def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to push predicates down into FROM and JOIN sources.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize; its nodes are modified in place.
    Returns:
        sqlglot.Expression: the same expression object, optimized.
    """
    root = build_scope(expression)
    if not root:
        return expression

    ref_counts = root.ref_count()

    # Scopes are visited in the reverse of `root.traverse()` order.
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        where = select.args.get("where")

        if where:
            sources = scope.selected_sources
            # A right join can only push down to itself and not to the source FROM table,
            # so the first RIGHT-joined source becomes the only candidate.
            for name, (node, source) in sources.items():
                enclosing = node.find_ancestor(exp.Join, exp.From)
                if isinstance(enclosing, exp.Join) and enclosing.side == "RIGHT":
                    targets = {name: (node, source)}
                    break
            else:
                targets = sources
            pushdown(where.this, targets, ref_counts)

        # An ON condition may only be pushed into its own joined source, never into
        # other joins, so only that one source is offered.
        for join in select.args.get("joins") or []:
            alias = join.alias_or_name
            if alias not in scope.selected_sources:
                continue
            pushdown(join.args.get("on"), {alias: scope.selected_sources[alias]}, ref_counts)

    return expression

Rewrite a sqlglot AST to push predicates down into FROM and JOIN sources.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:
  • sqlglot.Expression: optimized expression

def pushdown(condition, sources, scope_ref_count):
    """
    Simplify `condition` in place and push its blocks down into `sources`.

    The condition is treated as CNF-like when it is normalized (CNF) or is not in DNF.
    In that case its AND operands are pushed one by one. Otherwise its OR operands are
    handed to the DNF strategy.

    Args:
        condition: the WHERE/ON condition to push down; falsy values are ignored.
        sources: mapping of source name to a ``(node, source)`` pair.
        scope_ref_count: scope reference counts keyed by ``id(scope)``.
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition))
    is_cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    connector = exp.And if is_cnf_like else exp.Or
    if isinstance(condition, connector):
        blocks = list(condition.flatten())
    else:
        blocks = [condition]

    strategy = pushdown_cnf if is_cnf_like else pushdown_dnf
    strategy(blocks, sources, scope_ref_count)
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Every block that can be pushed down is replaced by TRUE at its original location. It
    is then moved into a JOIN's ON clause, or into a derived SELECT's WHERE clause (HAVING
    when the un-aliased predicate contains an aggregate).

    When a block could go into several joins, the join that appears last in the query is
    chosen. The other joined tables the block references then precede it. Attaching the
    block to an earlier join could reference a table that has not been joined yet.

    Args:
        predicates: the blocks (AND operands) of the condition.
        scope: mapping of source name to a ``(node, source)`` pair.
        scope_ref_count: scope reference counts keyed by ``id(scope)``.
    """

    def join_position(join):
        # Index of `join` in its parent's "joins" list, or -1 if it can't be located.
        parent = join.parent
        siblings = (parent.args.get("joins") if parent else None) or []
        for position, sibling in enumerate(siblings):
            if sibling is join:
                return position
        return -1

    for predicate in predicates:
        nodes = list(nodes_for_predicate(predicate, scope, scope_ref_count).values())
        joins = [node for node in nodes if isinstance(node, exp.Join)]

        if joins:
            # max() keeps the first candidate on ties, matching the old first-join choice
            # when positions can't be determined.
            predicate.replace(exp.true())
            max(joins, key=join_position).on(predicate, copy=False)
            continue

        for node in nodes:
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(node, predicate)
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)

If the predicates are in a CNF-like form, each block can be handled independently. Every block that can be pushed down is replaced with TRUE in its parent and moved into its target node.

def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions on tables that
    are referenced in all blocks. Additionally, we can't remove predicates from their
    original form, so whatever is pushed down is an extra, implied filter.

    For each such table, the conjuncts of every block that mention the table are ANDed,
    and the per-block results are ORed:

        (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)

    A table is skipped entirely if any block cannot be pushed to it. ORing only the
    pushable blocks would produce a stricter filter than the original condition.

    Args:
        predicates: the blocks (OR operands) of the condition.
        scope: mapping of source name to a ``(node, source)`` pair.
        scope_ref_count: scope reference counts keyed by ``id(scope)``.
    """
    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    pushdown_tables = set()

    for a in predicates:
        a_tables = exp.column_table_names(a)

        for b in predicates:
            a_tables &= exp.column_table_names(b)

        pushdown_tables.update(a_tables)

    for table in sorted(pushdown_tables):
        target = None
        condition = None

        for predicate in predicates:
            node = nodes_for_predicate(predicate, scope, scope_ref_count).get(table)
            block_condition = None

            if node is not None:
                block = predicate
                while isinstance(block, exp.Paren):
                    block = block.this
                conjuncts = block.flatten() if isinstance(block, exp.And) else [block]

                # Keep whole conjuncts, copied so the original condition is left intact.
                # Taking only a column's nearest Condition ancestor could drop an enclosing
                # NOT, or turn a function/arithmetic operand into a bogus filter.
                for conjunct in conjuncts:
                    if table in exp.column_table_names(conjunct):
                        piece = conjunct.copy()
                        block_condition = (
                            exp.and_(block_condition, piece)
                            if block_condition is not None
                            else piece
                        )

            if block_condition is None:
                # This block can't be pushed to `table`. An OR of the other blocks alone
                # would reject rows this block accepts, so give up on this table.
                condition = None
                break

            target = node
            condition = (
                exp.or_(condition, block_condition)
                if condition is not None
                else block_condition
            )

        if condition is None:
            continue

        # Push only this table's condition, so earlier tables are not pushed twice.
        if isinstance(target, exp.Join):
            target.on(condition, copy=False)
        elif isinstance(target, exp.Select):
            inner_predicate = replace_aliases(target, condition)
            if find_in_scope(inner_predicate, exp.AggFunc):
                target.having(inner_predicate, copy=False)
            else:
                target.where(inner_predicate, copy=False)

If the predicates are in DNF form, we can only push down conditions on tables that are referenced in every block. Additionally, the original predicates cannot be removed from their original position, so anything pushed down is an extra, implied filter.

def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes a predicate may be pushed down into.

    Args:
        predicate: a single block of a CNF/DNF condition.
        sources: mapping of source name to a ``(node, source)`` pair.
        scope_ref_count: scope reference counts keyed by ``id(scope)``.

    Returns:
        dict mapping table name to the JOIN or SELECT node that may receive the
        predicate. An empty dict means the predicate must not be pushed anywhere.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a FROM source that is not a plain table (e.g. a CTE or subquery) can receive
        # the predicate in its own query; presumably `source` is a Scope here
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # outer joins other than RIGHT disable pushdown for the whole predicate
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places, and filtering before LIMIT/OFFSET would change which rows
            # the derived table returns.
            if (
                not node.args.get("group")
                and not node.args.get("limit")
                and not node.args.get("offset")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes
def replace_aliases(source, predicate):
    """
    Return a copy of `predicate` in which each column naming one of `source`'s
    projections is replaced by a copy of that projection's expression.

    Aliased projections map their alias to the aliased expression. Other projections
    map their own name to themselves.
    """
    projections = {}
    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            key, value = projection.alias, projection.this
        else:
            key, value = projection.name, projection
        projections[key] = value

    def substitute(node):
        if isinstance(node, exp.Column) and node.name in projections:
            return projections[node.name].copy()
        return node

    return predicate.transform(substitute)