Edit on GitHub

sqlglot.optimizer.pushdown_predicates

  1from sqlglot import exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import build_scope, find_in_scope
  4from sqlglot.optimizer.simplify import simplify
  5
  6
  7def pushdown_predicates(expression, dialect=None):
  8    """
  9    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
 10
 11    Example:
 12        >>> import sqlglot
 13        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
 14        >>> expression = sqlglot.parse_one(sql)
 15        >>> pushdown_predicates(expression).sql()
 16        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
 17
 18    Args:
 19        expression (sqlglot.Expression): expression to optimize
 20    Returns:
 21        sqlglot.Expression: optimized expression
 22    """
 23    root = build_scope(expression)
 24
 25    if root:
 26        scope_ref_count = root.ref_count()
 27
 28        for scope in reversed(list(root.traverse())):
 29            select = scope.expression
 30            where = select.args.get("where")
 31            if where:
 32                selected_sources = scope.selected_sources
 33                join_index = {
 34                    join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
 35                }
 36
 37                # a right join can only push down to itself and not the source FROM table
 38                for k, (node, source) in selected_sources.items():
 39                    parent = node.find_ancestor(exp.Join, exp.From)
 40                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
 41                        selected_sources = {k: (node, source)}
 42                        break
 43
 44                pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)
 45
 46            # joins should only pushdown into itself, not to other joins
 47            # so we limit the selected sources to only itself
 48            for join in select.args.get("joins") or []:
 49                name = join.alias_or_name
 50                if name in scope.selected_sources:
 51                    pushdown(
 52                        join.args.get("on"),
 53                        {name: scope.selected_sources[name]},
 54                        scope_ref_count,
 55                        dialect,
 56                    )
 57
 58    return expression
 59
 60
 61def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
 62    if not condition:
 63        return
 64
 65    condition = condition.replace(simplify(condition, dialect=dialect))
 66    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 67
 68    predicates = list(
 69        condition.flatten()
 70        if isinstance(condition, exp.And if cnf_like else exp.Or)
 71        else [condition]
 72    )
 73
 74    if cnf_like:
 75        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
 76    else:
 77        pushdown_dnf(predicates, sources, scope_ref_count)
 78
 79
 80def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
 81    """
 82    If the predicates are in CNF like form, we can simply replace each block in the parent.
 83    """
 84    join_index = join_index or {}
 85    for predicate in predicates:
 86        for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
 87            if isinstance(node, exp.Join):
 88                name = node.alias_or_name
 89                predicate_tables = exp.column_table_names(predicate, name)
 90
 91                # Don't push the predicate if it references tables that appear in later joins
 92                this_index = join_index[name]
 93                if all(join_index.get(table, -1) < this_index for table in predicate_tables):
 94                    predicate.replace(exp.true())
 95                    node.on(predicate, copy=False)
 96                    break
 97            if isinstance(node, exp.Select):
 98                predicate.replace(exp.true())
 99                inner_predicate = replace_aliases(node, predicate)
100                if find_in_scope(inner_predicate, exp.AggFunc):
101                    node.having(inner_predicate, copy=False)
102                else:
103                    node.where(inner_predicate, copy=False)
104
105
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    Push down predicates that are in DNF form.

    A condition is only pushed to a table when every block of the disjunction both
    references that table and resolves to a push-down target for it. The pushed condition
    is the OR of all blocks; pushing only some of the blocks would filter out rows that
    another block accepts. The original predicates are left in place.

    Args:
        predicates: the blocks of the disjunction
        sources: passed to `nodes_for_predicate` to resolve target nodes
        scope_ref_count: passed to `nodes_for_predicate`
    """
    predicates = list(predicates)
    if not predicates:
        return

    # Find all the tables that can be pushed down to: those referenced in every block.
    # For (a.x AND b.x) OR (a.y AND c.y) only table a qualifies.
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set(table_sets[0]).intersection(*table_sets[1:])

    # The targets of a block do not depend on the table being processed, so resolve
    # them once per block instead of once per (table, block) pair.
    targets_per_block = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # Skip the table unless every block yields a target for it; otherwise the pushed
        # condition would be missing blocks and be stricter than the original.
        if any(table not in targets for targets in targets_per_block):
            continue

        node = targets_per_block[0][table]

        # Build a fresh OR of all blocks for this table.
        condition = predicates[0]
        for predicate in predicates[1:]:
            condition = exp.or_(condition, predicate)

        # With a single block no OR is built, so copy explicitly to avoid re-parenting
        # the original predicate.
        if len(predicates) == 1:
            condition = condition.copy()

        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_predicate = replace_aliases(node, condition)
            # Predicates that contain an aggregate go to HAVING, the rest to WHERE.
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)
153
154
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the node the predicate may be pushed into.

    Args:
        predicate: the predicate to find targets for
        sources: mapping of table name to a `(node, source)` pair
        scope_ref_count: mapping keyed by `id(source)`, compared against 2 to detect
            sources that are referenced in multiple places
    Returns:
        dict: table name to an `exp.Join` or `exp.Select` target. Empty when one of the
        referenced tables sits in a join whose side is set to something other than RIGHT,
        or comes from a recursive WITH.
    """
    targets = {}
    tables = exp.column_table_names(predicate)

    # True when the nearest Join/Where ancestor of the predicate is a Where.
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # If the predicate is in a where statement we can try to push it down:
        # climb to the enclosing join or from statement.
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # A FROM whose source is not a plain table (e.g. a CTE) is pushed into the
        # source's own expression, unless the WITH it belongs to is recursive.
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            side = node.side
            if side and side != "RIGHT":
                return {}
            targets[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )

            # We also can't push into a select that is grouped, or that is referenced
            # in multiple places.
            blocked = (
                node.args.get("group")
                or scope_ref_count[id(source)] >= 2
                or has_window_expression
            )
            if not blocked:
                targets[table] = node

    return targets
193
194
def replace_aliases(source, predicate):
    """
    Rewrite `predicate` so that columns naming a projection of `source` become copies of
    the projected expression.

    Args:
        source: node whose `selects` supply the name to expression mapping
        predicate: the predicate to transform
    Returns:
        The result of `predicate.transform` with the substitution applied.
    """
    substitutions = {}

    for projection in source.selects:
        aliased = isinstance(projection, exp.Alias)
        # An alias maps its alias name to the aliased expression; anything else maps
        # its own name to itself.
        key = projection.alias if aliased else projection.name
        substitutions[key] = projection.this if aliased else projection

    def _substitute(node):
        if not isinstance(node, exp.Column):
            return node
        if node.name in substitutions:
            return substitutions[node.name].copy()
        return node

    return predicate.transform(_substitute)
def pushdown_predicates(expression, dialect=None):
 8def pushdown_predicates(expression, dialect=None):
 9    """
10    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
11
12    Example:
13        >>> import sqlglot
14        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
15        >>> expression = sqlglot.parse_one(sql)
16        >>> pushdown_predicates(expression).sql()
17        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
18
19    Args:
20        expression (sqlglot.Expression): expression to optimize
21    Returns:
22        sqlglot.Expression: optimized expression
23    """
24    root = build_scope(expression)
25
26    if root:
27        scope_ref_count = root.ref_count()
28
29        for scope in reversed(list(root.traverse())):
30            select = scope.expression
31            where = select.args.get("where")
32            if where:
33                selected_sources = scope.selected_sources
34                join_index = {
35                    join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
36                }
37
38                # a right join can only push down to itself and not the source FROM table
39                for k, (node, source) in selected_sources.items():
40                    parent = node.find_ancestor(exp.Join, exp.From)
41                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
42                        selected_sources = {k: (node, source)}
43                        break
44
45                pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)
46
47            # joins should only pushdown into itself, not to other joins
48            # so we limit the selected sources to only itself
49            for join in select.args.get("joins") or []:
50                name = join.alias_or_name
51                if name in scope.selected_sources:
52                    pushdown(
53                        join.args.get("on"),
54                        {name: scope.selected_sources[name]},
55                        scope_ref_count,
56                        dialect,
57                    )
58
59    return expression

Rewrite a sqlglot AST to push predicates down into FROM sources and JOINs.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:

sqlglot.Expression: optimized expression

def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify `condition`, split it into predicates and dispatch to the CNF or DNF pushdown.

    Args:
        condition: the condition to push down; nothing happens when it is falsy
        sources: candidate sources, passed through to the CNF/DNF routines
        scope_ref_count: passed through to the CNF/DNF routines
        dialect: passed to `simplify`
        join_index: only used by the CNF routine
    """
    if not condition:
        return

    # Swap the condition for its simplified form and continue with whatever
    # `replace` returns.
    simplified = condition.replace(simplify(condition, dialect=dialect))

    # Anything that is not strictly DNF is handled as CNF-like.
    cnf_like = normalized(simplified) or not normalized(simplified, dnf=True)

    # CNF-like conditions are split on AND, DNF conditions on OR.
    connector = exp.And if cnf_like else exp.Or
    if isinstance(simplified, connector):
        predicates = list(simplified.flatten())
    else:
        predicates = [simplified]

    if not cnf_like:
        pushdown_dnf(predicates, sources, scope_ref_count)
        return

    pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
 81def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
 82    """
 83    If the predicates are in CNF like form, we can simply replace each block in the parent.
 84    """
 85    join_index = join_index or {}
 86    for predicate in predicates:
 87        for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
 88            if isinstance(node, exp.Join):
 89                name = node.alias_or_name
 90                predicate_tables = exp.column_table_names(predicate, name)
 91
 92                # Don't push the predicate if it references tables that appear in later joins
 93                this_index = join_index[name]
 94                if all(join_index.get(table, -1) < this_index for table in predicate_tables):
 95                    predicate.replace(exp.true())
 96                    node.on(predicate, copy=False)
 97                    break
 98            if isinstance(node, exp.Select):
 99                predicate.replace(exp.true())
100                inner_predicate = replace_aliases(node, predicate)
101                if find_in_scope(inner_predicate, exp.AggFunc):
102                    node.having(inner_predicate, copy=False)
103                else:
104                    node.where(inner_predicate, copy=False)

If the predicates are in CNF-like form, each predicate that is pushed down can simply be replaced in its parent.

def pushdown_dnf(predicates, sources, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    Push down predicates that are in DNF form.

    A condition is only pushed to a table when every block of the disjunction both
    references that table and resolves to a push-down target for it. The pushed condition
    is the OR of all blocks; pushing only some of the blocks would filter out rows that
    another block accepts. The original predicates are left in place.

    Args:
        predicates: the blocks of the disjunction
        sources: passed to `nodes_for_predicate` to resolve target nodes
        scope_ref_count: passed to `nodes_for_predicate`
    """
    predicates = list(predicates)
    if not predicates:
        return

    # Find all the tables that can be pushed down to: those referenced in every block.
    # For (a.x AND b.x) OR (a.y AND c.y) only table a qualifies.
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set(table_sets[0]).intersection(*table_sets[1:])

    # The targets of a block do not depend on the table being processed, so resolve
    # them once per block instead of once per (table, block) pair.
    targets_per_block = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # Skip the table unless every block yields a target for it; otherwise the pushed
        # condition would be missing blocks and be stricter than the original.
        if any(table not in targets for targets in targets_per_block):
            continue

        node = targets_per_block[0][table]

        # Build a fresh OR of all blocks for this table.
        condition = predicates[0]
        for predicate in predicates[1:]:
            condition = exp.or_(condition, predicate)

        # With a single block no OR is built, so copy explicitly to avoid re-parenting
        # the original predicate.
        if len(predicates) == 1:
            condition = condition.copy()

        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_predicate = replace_aliases(node, condition)
            # Predicates that contain an aggregate go to HAVING, the rest to WHERE.
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)

If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.

def nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the node the predicate may be pushed into.

    Args:
        predicate: the predicate to find targets for
        sources: mapping of table name to a `(node, source)` pair
        scope_ref_count: mapping keyed by `id(source)`, compared against 2 to detect
            sources that are referenced in multiple places
    Returns:
        dict: table name to an `exp.Join` or `exp.Select` target. Empty when one of the
        referenced tables sits in a join whose side is set to something other than RIGHT,
        or comes from a recursive WITH.
    """
    targets = {}
    tables = exp.column_table_names(predicate)

    # True when the nearest Join/Where ancestor of the predicate is a Where.
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # If the predicate is in a where statement we can try to push it down:
        # climb to the enclosing join or from statement.
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # A FROM whose source is not a plain table (e.g. a CTE) is pushed into the
        # source's own expression, unless the WITH it belongs to is recursive.
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            side = node.side
            if side and side != "RIGHT":
                return {}
            targets[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )

            # We also can't push into a select that is grouped, or that is referenced
            # in multiple places.
            blocked = (
                node.args.get("group")
                or scope_ref_count[id(source)] >= 2
                or has_window_expression
            )
            if not blocked:
                targets[table] = node

    return targets
def replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Rewrite `predicate` so that columns naming a projection of `source` become copies of
    the projected expression.

    Args:
        source: node whose `selects` supply the name to expression mapping
        predicate: the predicate to transform
    Returns:
        The result of `predicate.transform` with the substitution applied.
    """
    substitutions = {}

    for projection in source.selects:
        aliased = isinstance(projection, exp.Alias)
        # An alias maps its alias name to the aliased expression; anything else maps
        # its own name to itself.
        key = projection.alias if aliased else projection.name
        substitutions[key] = projection.this if aliased else projection

    def _substitute(node):
        if not isinstance(node, exp.Column):
            return node
        if node.name in substitutions:
            return substitutions[node.name].copy()
        return node

    return predicate.transform(_substitute)