Edit on GitHub

sqlglot.optimizer.pushdown_predicates

  1from sqlglot import exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import build_scope
  4from sqlglot.optimizer.simplify import simplify
  5
  6
  7def pushdown_predicates(expression):
  8    """
  9    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
 10
 11    Example:
 12        >>> import sqlglot
 13        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
 14        >>> expression = sqlglot.parse_one(sql)
 15        >>> pushdown_predicates(expression).sql()
 16        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
 17
 18    Args:
 19        expression (sqlglot.Expression): expression to optimize
 20    Returns:
 21        sqlglot.Expression: optimized expression
 22    """
 23    root = build_scope(expression)
 24
 25    if root:
 26        scope_ref_count = root.ref_count()
 27
 28        for scope in reversed(list(root.traverse())):
 29            select = scope.expression
 30            where = select.args.get("where")
 31            if where:
 32                selected_sources = scope.selected_sources
 33                # a right join can only push down to itself and not the source FROM table
 34                for k, (node, source) in selected_sources.items():
 35                    parent = node.find_ancestor(exp.Join, exp.From)
 36                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
 37                        selected_sources = {k: (node, source)}
 38                        break
 39                pushdown(where.this, selected_sources, scope_ref_count)
 40
 41            # joins should only pushdown into itself, not to other joins
 42            # so we limit the selected sources to only itself
 43            for join in select.args.get("joins") or []:
 44                name = join.this.alias_or_name
 45                pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
 46
 47    return expression
 48
 49
 50def pushdown(condition, sources, scope_ref_count):
 51    if not condition:
 52        return
 53
 54    condition = condition.replace(simplify(condition))
 55    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 56
 57    predicates = list(
 58        condition.flatten()
 59        if isinstance(condition, exp.And if cnf_like else exp.Or)
 60        else [condition]
 61    )
 62
 63    if cnf_like:
 64        pushdown_cnf(predicates, sources, scope_ref_count)
 65    else:
 66        pushdown_dnf(predicates, sources, scope_ref_count)
 67
 68
 69def pushdown_cnf(predicates, scope, scope_ref_count):
 70    """
 71    If the predicates are in CNF like form, we can simply replace each block in the parent.
 72    """
 73    for predicate in predicates:
 74        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
 75            if isinstance(node, exp.Join):
 76                predicate.replace(exp.true())
 77                node.on(predicate, copy=False)
 78                break
 79            if isinstance(node, exp.Select):
 80                predicate.replace(exp.true())
 81                node.where(replace_aliases(node, predicate), copy=False)
 82
 83
 84def pushdown_dnf(predicates, scope, scope_ref_count):
 85    """
 86    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
 87    Additionally, we can't remove predicates from their original form.
 88    """
 89    # find all the tables that can be pushdown too
 90    # these are tables that are referenced in all blocks of a DNF
 91    # (a.x AND b.x) OR (a.y AND c.y)
 92    # only table a can be push down
 93    pushdown_tables = set()
 94
 95    for a in predicates:
 96        a_tables = set(exp.column_table_names(a))
 97
 98        for b in predicates:
 99            a_tables &= set(exp.column_table_names(b))
100
101        pushdown_tables.update(a_tables)
102
103    conditions = {}
104
105    # for every pushdown table, find all related conditions in all predicates
106    # combine them with ORS
107    # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
108    for table in sorted(pushdown_tables):
109        for predicate in predicates:
110            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
111
112            if table not in nodes:
113                continue
114
115            predicate_condition = None
116
117            for column in predicate.find_all(exp.Column):
118                if column.table == table:
119                    condition = column.find_ancestor(exp.Condition)
120                    predicate_condition = (
121                        exp.and_(predicate_condition, condition)
122                        if predicate_condition
123                        else condition
124                    )
125
126            if predicate_condition:
127                conditions[table] = (
128                    exp.or_(conditions[table], predicate_condition)
129                    if table in conditions
130                    else predicate_condition
131                )
132
133        for name, node in nodes.items():
134            if name not in conditions:
135                continue
136
137            predicate = conditions[name]
138
139            if isinstance(node, exp.Join):
140                node.on(predicate, copy=False)
141            elif isinstance(node, exp.Select):
142                node.where(replace_aliases(node, predicate), copy=False)
143
144
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes a predicate may be pushed down to.

    Args:
        predicate: the predicate to find targets for
        sources: mapping of source name to a ``(node, source)`` pair
        scope_ref_count: reference counts keyed by ``id(source)``
    Returns:
        dict: table name -> ``exp.Join`` or ``exp.Select`` node that can take the
        predicate. An empty dict is returned as soon as a recursive WITH or a
        sided join other than RIGHT is encountered.
    """
    targets = {}
    tables = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in tables:
        entry = sources.get(table)
        node, source = entry if entry else (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # any sided join other than RIGHT rules out pushing this predicate at all
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        # a select is only a target when the predicate references a single table
        if not isinstance(node, exp.Select) or len(tables) != 1:
            continue

        # we can't push down predicates to select statements if they are grouped
        # or referenced in multiple places
        if node.args.get("group") or scope_ref_count[id(source)] >= 2:
            continue

        # we can't push down window expressions
        has_window = any(
            projection for projection in node.selects if projection.find(exp.Window)
        )
        if not has_window:
            targets[table] = node

    return targets
183
184
def replace_aliases(source, predicate):
    """
    Rewrite the columns of a predicate in terms of a select's projections.

    Args:
        source: the select whose projections (``source.selects``) are looked up
        predicate: the predicate whose columns should be rewritten
    Returns:
        The result of ``predicate.transform``, where every column whose name
        matches a projection is replaced by a copy of the projected expression.
    """
    replacements = {}

    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        # aliased projections map alias -> aliased expression,
        # everything else maps its own name -> itself
        key = projection.alias if is_alias else projection.name
        replacements[key] = projection.this if is_alias else projection

    def _replace_alias(node):
        if isinstance(node, exp.Column):
            name = node.name
            if name in replacements:
                return replacements[name].copy()
        return node

    return predicate.transform(_replace_alias)
def pushdown_predicates(expression):
 8def pushdown_predicates(expression):
 9    """
10    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
11
12    Example:
13        >>> import sqlglot
14        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
15        >>> expression = sqlglot.parse_one(sql)
16        >>> pushdown_predicates(expression).sql()
17        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
18
19    Args:
20        expression (sqlglot.Expression): expression to optimize
21    Returns:
22        sqlglot.Expression: optimized expression
23    """
24    root = build_scope(expression)
25
26    if root:
27        scope_ref_count = root.ref_count()
28
29        for scope in reversed(list(root.traverse())):
30            select = scope.expression
31            where = select.args.get("where")
32            if where:
33                selected_sources = scope.selected_sources
34                # a right join can only push down to itself and not the source FROM table
35                for k, (node, source) in selected_sources.items():
36                    parent = node.find_ancestor(exp.Join, exp.From)
37                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
38                        selected_sources = {k: (node, source)}
39                        break
40                pushdown(where.this, selected_sources, scope_ref_count)
41
42            # joins should only pushdown into itself, not to other joins
43            # so we limit the selected sources to only itself
44            for join in select.args.get("joins") or []:
45                name = join.this.alias_or_name
46                pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
47
48    return expression

Rewrite the sqlglot AST to push predicates down into FROMs and JOINs.

Example:
>>> import sqlglot
>>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:

sqlglot.Expression: optimized expression

def pushdown(condition, sources, scope_ref_count):
def pushdown(condition, sources, scope_ref_count):
    """
    Simplify a condition in place and push its predicates down.

    The condition is replaced in its tree by its simplified form, split into
    top-level predicates, and dispatched to ``pushdown_cnf`` or ``pushdown_dnf``.

    Args:
        condition: the WHERE / ON condition to push down; falsy values are ignored
        sources: mapping of source name to a ``(node, source)`` pair
        scope_ref_count: reference counts keyed by ``id(source)``
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition))

    # anything that is not strictly DNF-normalized is handled as CNF-like
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    # CNF-like conditions are split on AND, DNF conditions on OR
    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    handler = pushdown_cnf if cnf_like else pushdown_dnf
    handler(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, scope, scope_ref_count):
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    Push down predicates that are in CNF-like form.

    Each predicate that has a target node is swapped for TRUE in its original
    position and attached to that target instead.

    Args:
        predicates: list of top-level predicates of the condition
        scope: mapping of source name to ``(node, source)``, forwarded to
            ``nodes_for_predicate``
        scope_ref_count: reference counts keyed by ``id(source)``
    """
    for predicate in predicates:
        targets = nodes_for_predicate(predicate, scope, scope_ref_count)

        for target in targets.values():
            goes_to_join = isinstance(target, exp.Join)
            if not goes_to_join and not isinstance(target, exp.Select):
                continue

            # the predicate leaves its original spot, TRUE takes its place
            predicate.replace(exp.true())

            if goes_to_join:
                target.on(predicate, copy=False)
                # once attached to a join, the predicate is not pushed anywhere else
                break

            # column references are rewritten to the select's underlying expressions
            target.where(replace_aliases(target, predicate), copy=False)

If the predicates are in CNF-like form, we can simply replace each block in the parent.

def pushdown_dnf(predicates, scope, scope_ref_count):
 85def pushdown_dnf(predicates, scope, scope_ref_count):
 86    """
 87    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
 88    Additionally, we can't remove predicates from their original form.
 89    """
 90    # find all the tables that can be pushdown too
 91    # these are tables that are referenced in all blocks of a DNF
 92    # (a.x AND b.x) OR (a.y AND c.y)
 93    # only table a can be push down
 94    pushdown_tables = set()
 95
 96    for a in predicates:
 97        a_tables = set(exp.column_table_names(a))
 98
 99        for b in predicates:
100            a_tables &= set(exp.column_table_names(b))
101
102        pushdown_tables.update(a_tables)
103
104    conditions = {}
105
106    # for every pushdown table, find all related conditions in all predicates
107    # combine them with ORS
108    # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
109    for table in sorted(pushdown_tables):
110        for predicate in predicates:
111            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
112
113            if table not in nodes:
114                continue
115
116            predicate_condition = None
117
118            for column in predicate.find_all(exp.Column):
119                if column.table == table:
120                    condition = column.find_ancestor(exp.Condition)
121                    predicate_condition = (
122                        exp.and_(predicate_condition, condition)
123                        if predicate_condition
124                        else condition
125                    )
126
127            if predicate_condition:
128                conditions[table] = (
129                    exp.or_(conditions[table], predicate_condition)
130                    if table in conditions
131                    else predicate_condition
132                )
133
134        for name, node in nodes.items():
135            if name not in conditions:
136                continue
137
138            predicate = conditions[name]
139
140            if isinstance(node, exp.Join):
141                node.on(predicate, copy=False)
142            elif isinstance(node, exp.Select):
143                node.where(replace_aliases(node, predicate), copy=False)

If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.

def nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes a predicate may be pushed down to.

    Args:
        predicate: the predicate to find targets for
        sources: mapping of source name to a ``(node, source)`` pair
        scope_ref_count: reference counts keyed by ``id(source)``
    Returns:
        dict: table name -> ``exp.Join`` or ``exp.Select`` node that can take the
        predicate. An empty dict is returned as soon as a recursive WITH or a
        sided join other than RIGHT is encountered.
    """
    targets = {}
    tables = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in tables:
        entry = sources.get(table)
        node, source = entry if entry else (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # any sided join other than RIGHT rules out pushing this predicate at all
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        # a select is only a target when the predicate references a single table
        if not isinstance(node, exp.Select) or len(tables) != 1:
            continue

        # we can't push down predicates to select statements if they are grouped
        # or referenced in multiple places
        if node.args.get("group") or scope_ref_count[id(source)] >= 2:
            continue

        # we can't push down window expressions
        has_window = any(
            projection for projection in node.selects if projection.find(exp.Window)
        )
        if not has_window:
            targets[table] = node

    return targets
def replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Rewrite the columns of a predicate in terms of a select's projections.

    Args:
        source: the select whose projections (``source.selects``) are looked up
        predicate: the predicate whose columns should be rewritten
    Returns:
        The result of ``predicate.transform``, where every column whose name
        matches a projection is replaced by a copy of the projected expression.
    """
    replacements = {}

    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        # aliased projections map alias -> aliased expression,
        # everything else maps its own name -> itself
        key = projection.alias if is_alias else projection.name
        replacements[key] = projection.this if is_alias else projection

    def _replace_alias(node):
        if isinstance(node, exp.Column):
            name = node.name
            if name in replacements:
                return replacements[name].copy()
        return node

    return predicate.transform(_replace_alias)