Edit on GitHub

sqlglot.optimizer.pushdown_predicates

  1from sqlglot import exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import build_scope
  4from sqlglot.optimizer.simplify import simplify
  5
  6
  7def pushdown_predicates(expression):
  8    """
  9    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
 10
 11    Example:
 12        >>> import sqlglot
 13        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
 14        >>> expression = sqlglot.parse_one(sql)
 15        >>> pushdown_predicates(expression).sql()
 16        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
 17
 18    Args:
 19        expression (sqlglot.Expression): expression to optimize
 20    Returns:
 21        sqlglot.Expression: optimized expression
 22    """
 23    root = build_scope(expression)
 24    scope_ref_count = root.ref_count()
 25
 26    for scope in reversed(list(root.traverse())):
 27        select = scope.expression
 28        where = select.args.get("where")
 29        if where:
 30            selected_sources = scope.selected_sources
 31            # a right join can only push down to itself and not the source FROM table
 32            for k, (node, source) in selected_sources.items():
 33                parent = node.find_ancestor(exp.Join, exp.From)
 34                if isinstance(parent, exp.Join) and parent.side == "RIGHT":
 35                    selected_sources = {k: (node, source)}
 36                    break
 37            pushdown(where.this, selected_sources, scope_ref_count)
 38
 39        # joins should only pushdown into itself, not to other joins
 40        # so we limit the selected sources to only itself
 41        for join in select.args.get("joins") or []:
 42            name = join.this.alias_or_name
 43            pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
 44
 45    return expression
 46
 47
 48def pushdown(condition, sources, scope_ref_count):
 49    if not condition:
 50        return
 51
 52    condition = condition.replace(simplify(condition))
 53    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 54
 55    predicates = list(
 56        condition.flatten()
 57        if isinstance(condition, exp.And if cnf_like else exp.Or)
 58        else [condition]
 59    )
 60
 61    if cnf_like:
 62        pushdown_cnf(predicates, sources, scope_ref_count)
 63    else:
 64        pushdown_dnf(predicates, sources, scope_ref_count)
 65
 66
 67def pushdown_cnf(predicates, scope, scope_ref_count):
 68    """
 69    If the predicates are in CNF like form, we can simply replace each block in the parent.
 70    """
 71    for predicate in predicates:
 72        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
 73            if isinstance(node, exp.Join):
 74                predicate.replace(exp.true())
 75                node.on(predicate, copy=False)
 76                break
 77            if isinstance(node, exp.Select):
 78                predicate.replace(exp.true())
 79                node.where(replace_aliases(node, predicate), copy=False)
 80
 81
 82def pushdown_dnf(predicates, scope, scope_ref_count):
 83    """
 84    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
 85    Additionally, we can't remove predicates from their original form.
 86    """
 87    # find all the tables that can be pushdown too
 88    # these are tables that are referenced in all blocks of a DNF
 89    # (a.x AND b.x) OR (a.y AND c.y)
 90    # only table a can be push down
 91    pushdown_tables = set()
 92
 93    for a in predicates:
 94        a_tables = set(exp.column_table_names(a))
 95
 96        for b in predicates:
 97            a_tables &= set(exp.column_table_names(b))
 98
 99        pushdown_tables.update(a_tables)
100
101    conditions = {}
102
103    # for every pushdown table, find all related conditions in all predicates
104    # combine them with ORS
105    # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
106    for table in sorted(pushdown_tables):
107        for predicate in predicates:
108            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
109
110            if table not in nodes:
111                continue
112
113            predicate_condition = None
114
115            for column in predicate.find_all(exp.Column):
116                if column.table == table:
117                    condition = column.find_ancestor(exp.Condition)
118                    predicate_condition = (
119                        exp.and_(predicate_condition, condition)
120                        if predicate_condition
121                        else condition
122                    )
123
124            if predicate_condition:
125                conditions[table] = (
126                    exp.or_(conditions[table], predicate_condition)
127                    if table in conditions
128                    else predicate_condition
129                )
130
131        for name, node in nodes.items():
132            if name not in conditions:
133                continue
134
135            predicate = conditions[name]
136
137            if isinstance(node, exp.Join):
138                node.on(predicate, copy=False)
139            elif isinstance(node, exp.Select):
140                node.where(replace_aliases(node, predicate), copy=False)
141
142
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find, per referenced table, the node a predicate may be attached to.

    Args:
        predicate: the predicate to place.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by ``id(source)``.

    Returns:
        dict: table name -> join or select node. An empty dict is returned as
        soon as a recursive WITH or a join with a side other than "RIGHT" is hit.
    """
    targets = {}
    referenced = exp.column_table_names(predicate)
    enclosing = predicate.find_ancestor(exp.Join, exp.Where)
    in_where = isinstance(enclosing, exp.Where)

    for table in referenced:
        entry = sources.get(table)
        if entry:
            node, source = entry
        else:
            node, source = None, None

        # a predicate living in a WHERE clause is placed relative to the
        # enclosing JOIN or FROM of its source
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a FROM over something that is not a plain table (e.g. a CTE) is
        # resolved to the expression behind that source
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        if not isinstance(node, exp.Select) or len(referenced) != 1:
            continue

        # window expressions block the push down
        has_window = any(
            projection for projection in node.selects if projection.find(exp.Window)
        )

        # grouped selects, selects referenced in multiple places and selects
        # with window expressions are left alone
        if node.args.get("group"):
            continue
        if not scope_ref_count[id(source)] < 2:
            continue
        if has_window:
            continue

        targets[table] = node

    return targets
181
182
def replace_aliases(source, predicate):
    """
    Rewrite the columns of a predicate in terms of a select's projections.

    Args:
        source: a select whose ``selects`` supply the projections.
        predicate: the predicate to transform.

    Returns:
        The result of ``predicate.transform`` where every column whose name
        matches a projection is swapped for a copy of that projection
        (the aliased expression for aliases, the projection itself otherwise).
    """
    projections = {}

    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        key = projection.alias if is_alias else projection.name
        projections[key] = projection.this if is_alias else projection

    def _substitute(node):
        if not isinstance(node, exp.Column):
            return node
        if node.name not in projections:
            return node
        return projections[node.name].copy()

    return predicate.transform(_substitute)
def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to push predicates down into FROMs and JOINs.

    Every scope returned by ``build_scope(expression).traverse()`` is visited in
    reverse order. For each scope, the WHERE condition and the ON condition of
    every join are handed to ``pushdown``. The tree is modified in place and the
    same ``expression`` object is returned.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    root = build_scope(expression)
    scope_ref_count = root.ref_count()

    for scope in reversed(list(root.traverse())):
        select = scope.expression
        where = select.args.get("where")
        if where:
            selected_sources = scope.selected_sources
            # a right join can only push down to itself and not the source FROM table
            for k, (node, source) in selected_sources.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                    selected_sources = {k: (node, source)}
                    break
            pushdown(where.this, selected_sources, scope_ref_count)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        for join in select.args.get("joins") or []:
            name = join.this.alias_or_name
            # A join whose name is not a key of the selected sources has nothing
            # to push into; skip it instead of failing with a KeyError.
            if name not in scope.selected_sources:
                continue
            pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)

    return expression

Rewrite a sqlglot AST to push predicates down into FROMs and JOINs.

Example:
>>> import sqlglot
>>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:

sqlglot.Expression: optimized expression

def pushdown(condition, sources, scope_ref_count):
    """
    Simplify a condition, split it into predicates and push those down.

    The condition is replaced in its tree by its simplified form. It is treated
    as CNF-like when ``normalized(condition)`` holds or when
    ``normalized(condition, dnf=True)`` does not; otherwise it is handled as DNF.

    Args:
        condition: the condition to push down; a falsy value is a no-op.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by ``id(source)``.
    """
    if not condition:
        return

    simplified = condition.replace(simplify(condition))
    is_cnf_like = normalized(simplified) or not normalized(simplified, dnf=True)

    # CNF-like conditions are split on AND, DNF conditions on OR
    connector = exp.And if is_cnf_like else exp.Or
    if isinstance(simplified, connector):
        predicates = list(simplified.flatten())
    else:
        predicates = [simplified]

    if is_cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count)
        return

    pushdown_dnf(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    Push down every predicate of a CNF-like condition on its own.

    A predicate that has a target is swapped for TRUE at its original position
    and attached to the target: to the ON clause of a join, or (after
    ``replace_aliases``) to the WHERE clause of a select.

    Args:
        predicates: the AND-ed predicates of the condition.
        scope: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by ``id(source)``.
    """
    for predicate in predicates:
        targets = nodes_for_predicate(predicate, scope, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                predicate.replace(exp.true())
                target.on(predicate, copy=False)
                # once attached to a join, stop looking at further targets
                break

            if isinstance(target, exp.Select):
                predicate.replace(exp.true())
                rewritten = replace_aliases(target, predicate)
                target.where(rewritten, copy=False)

If the predicates are in CNF-like form, we can simply replace each block in the parent.

def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates (list): the OR-ed blocks of the condition.
        scope (dict): maps a source name to a (node, source) pair; forwarded to
            `nodes_for_predicate`.
        scope_ref_count (dict): reference counts keyed by `id(source)`; forwarded to
            `nodes_for_predicate`.
    """
    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    table_sets = [set(exp.column_table_names(predicate)) for predicate in predicates]
    # a single intersection over all blocks; no blocks means nothing to push down
    pushdown_tables = set.intersection(*table_sets) if table_sets else set()

    conditions = {}

    # for every pushdown table, find all related conditions in all predicates
    # combine them with ORs
    # (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)

            if table not in nodes:
                continue

            predicate_condition = None

            # AND together the enclosing condition of every column of this table
            for column in predicate.find_all(exp.Column):
                if column.table == table:
                    condition = column.find_ancestor(exp.Condition)
                    predicate_condition = (
                        exp.and_(predicate_condition, condition)
                        if predicate_condition
                        else condition
                    )

            if predicate_condition:
                conditions[table] = (
                    exp.or_(conditions[table], predicate_condition)
                    if table in conditions
                    else predicate_condition
                )

        # NOTE(review): `nodes` is whatever the last predicate of the loop above
        # produced, and every name already present in `conditions` is attached
        # again on each table iteration -- confirm this is intended.
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                node.where(replace_aliases(node, predicate), copy=False)

If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.

def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find, per referenced table, the node a predicate may be attached to.

    Args:
        predicate: the predicate to place.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by ``id(source)``.

    Returns:
        dict: table name -> join or select node. An empty dict is returned as
        soon as a recursive WITH or a join with a side other than "RIGHT" is hit.
    """
    targets = {}
    referenced = exp.column_table_names(predicate)
    enclosing = predicate.find_ancestor(exp.Join, exp.Where)
    in_where = isinstance(enclosing, exp.Where)

    for table in referenced:
        entry = sources.get(table)
        if entry:
            node, source = entry
        else:
            node, source = None, None

        # a predicate living in a WHERE clause is placed relative to the
        # enclosing JOIN or FROM of its source
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a FROM over something that is not a plain table (e.g. a CTE) is
        # resolved to the expression behind that source
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        if not isinstance(node, exp.Select) or len(referenced) != 1:
            continue

        # window expressions block the push down
        has_window = any(
            projection for projection in node.selects if projection.find(exp.Window)
        )

        # grouped selects, selects referenced in multiple places and selects
        # with window expressions are left alone
        if node.args.get("group"):
            continue
        if not scope_ref_count[id(source)] < 2:
            continue
        if has_window:
            continue

        targets[table] = node

    return targets
def replace_aliases(source, predicate):
    """
    Rewrite the columns of a predicate in terms of a select's projections.

    Args:
        source: a select whose ``selects`` supply the projections.
        predicate: the predicate to transform.

    Returns:
        The result of ``predicate.transform`` where every column whose name
        matches a projection is swapped for a copy of that projection
        (the aliased expression for aliases, the projection itself otherwise).
    """
    projections = {}

    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        key = projection.alias if is_alias else projection.name
        projections[key] = projection.this if is_alias else projection

    def _substitute(node):
        if not isinstance(node, exp.Column):
            return node
        if node.name not in projections:
            return node
        return projections[node.name].copy()

    return predicate.transform(_substitute)