sqlglot.optimizer.pushdown_predicates
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify


def pushdown_predicates(expression):
    """
    Rewrite a sqlglot AST to push predicates down into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    root = build_scope(expression)
    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # visit the scopes in the reverse of the order `traverse` yields them
    for scope in reversed(list(root.traverse())):
        select = scope.expression

        where = select.args.get("where")
        if where:
            candidates = scope.selected_sources

            # a right join can only push down to itself and not the source FROM table
            for alias, (node, source) in candidates.items():
                anchor = node.find_ancestor(exp.Join, exp.From)
                if isinstance(anchor, exp.Join) and anchor.side == "RIGHT":
                    candidates = {alias: (node, source)}
                    break

            pushdown(where.this, candidates, scope_ref_count)

        # a join's ON condition may only be pushed into that join's own source,
        # never into another join, so each call receives a single-entry mapping
        for join in select.args.get("joins") or []:
            name = join.alias_or_name
            if name not in scope.selected_sources:
                continue
            on_condition = join.args.get("on")
            pushdown(on_condition, {name: scope.selected_sources[name]}, scope_ref_count)

    return expression


def pushdown(condition, sources, scope_ref_count):
    """
    Simplify a condition and hand its predicates to the CNF or DNF pushdown routine.

    Args:
        condition: the condition to push down; nothing happens when it is falsy.
        sources: mapping of source name to a ``(node, source)`` pair.
        scope_ref_count: reference counts, looked up by ``id(source)``.
    """
    if not condition:
        return

    # replace the condition with its simplified form and keep working on the result
    condition = condition.replace(simplify(condition))

    # everything that is not exclusively in DNF is treated as CNF-like
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    if cnf_like:
        connector, handler = exp.And, pushdown_cnf
    else:
        connector, handler = exp.Or, pushdown_dnf

    # only split the condition when its top-level connector matches the chosen form
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    handler(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.
    """
    for predicate in predicates:
        targets = nodes_for_predicate(predicate, scope, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                # leave TRUE where the predicate was and move the predicate into the ON clause
                predicate.replace(exp.true())
                target.on(predicate, copy=False)
                break

            if not isinstance(target, exp.Select):
                continue

            predicate.replace(exp.true())
            rewritten = replace_aliases(target, predicate)
            # a predicate containing an aggregate goes to HAVING, anything else to WHERE
            attach = target.having if find_in_scope(rewritten, exp.AggFunc) else target.where
            attach(rewritten, copy=False)


def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates: the blocks of the DNF condition.
        scope: mapping of source name to a ``(node, source)`` pair.
        scope_ref_count: reference counts, looked up by ``id(source)``.
    """
    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set(table_sets[0]).intersection(*table_sets[1:]) if table_sets else set()

    conditions = {}

    # for every pushdown table, find all related conditions in all predicates
    # combine them with ORs
    # (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
    for table in sorted(pushdown_tables):
        # stays empty only if there are no predicates, keeping the loop below safe
        nodes = {}

        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)

            if table not in nodes:
                continue

            predicate_condition = None

            for column in predicate.find_all(exp.Column):
                if column.table != table:
                    continue
                condition = column.find_ancestor(exp.Condition)
                predicate_condition = (
                    exp.and_(predicate_condition, condition) if predicate_condition else condition
                )

            if predicate_condition:
                conditions[table] = (
                    exp.or_(conditions[table], predicate_condition)
                    if table in conditions
                    else predicate_condition
                )

        # `nodes` holds the result for the last predicate examined above
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                inner_predicate = replace_aliases(node, predicate)
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by a predicate to the node the predicate may be pushed into.

    Args:
        predicate: the predicate whose column tables are inspected.
        sources: mapping of source name to a ``(node, source)`` pair.
        scope_ref_count: reference counts, looked up by ``id(source)``.
    Returns:
        dict: table name -> ``exp.Join`` or ``exp.Select`` node; empty when pushdown is ruled out.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # a join with a side other than RIGHT rules out pushdown for the whole predicate
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
            continue

        if not isinstance(node, exp.Select) or len(tables) != 1:
            continue

        if node.args.get("group"):
            continue

        # we can't push down predicates to select statements if they are referenced in
        # multiple places.
        if scope_ref_count[id(source)] >= 2:
            continue

        # We can't push down window expressions
        if any(projection for projection in node.selects if projection.find(exp.Window)):
            continue

        nodes[table] = node

    return nodes


def replace_aliases(source, predicate):
    """
    Return a transformed predicate whose columns are swapped for the matching projections of source.

    Args:
        source: expression exposing ``selects``.
        predicate: expression whose columns are looked up by name among those projections.
    """
    projections = {}

    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        key = projection.alias if is_alias else projection.name
        # an aliased projection maps to the expression it wraps
        projections[key] = projection.this if is_alias else projection

    def _substitute(node):
        if not isinstance(node, exp.Column) or node.name not in projections:
            return node
        return projections[node.name].copy()

    return predicate.transform(_substitute)
def
pushdown_predicates(expression):
def pushdown_predicates(expression):
    """
    Rewrite a sqlglot AST to push predicates down into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    root = build_scope(expression)
    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # visit the scopes in the reverse of the order `traverse` yields them
    for scope in reversed(list(root.traverse())):
        select = scope.expression

        where = select.args.get("where")
        if where:
            candidates = scope.selected_sources

            # a right join can only push down to itself and not the source FROM table
            for alias, (node, source) in candidates.items():
                anchor = node.find_ancestor(exp.Join, exp.From)
                if isinstance(anchor, exp.Join) and anchor.side == "RIGHT":
                    candidates = {alias: (node, source)}
                    break

            pushdown(where.this, candidates, scope_ref_count)

        # a join's ON condition may only be pushed into that join's own source,
        # never into another join, so each call receives a single-entry mapping
        for join in select.args.get("joins") or []:
            name = join.alias_or_name
            if name not in scope.selected_sources:
                continue
            on_condition = join.args.get("on")
            pushdown(on_condition, {name: scope.selected_sources[name]}, scope_ref_count)

    return expression
Rewrite a sqlglot AST to push predicates down into FROMs and JOINs.
Example:
>>> import sqlglot >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1" >>> expression = sqlglot.parse_one(sql) >>> pushdown_predicates(expression).sql() 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
pushdown(condition, sources, scope_ref_count):
def pushdown(condition, sources, scope_ref_count):
    """
    Simplify a condition and hand its predicates to the CNF or DNF pushdown routine.

    Args:
        condition: the condition to push down; nothing happens when it is falsy.
        sources: mapping of source names to the entries predicates may be pushed into.
        scope_ref_count: passed through unchanged to the CNF/DNF routine.
    """
    if not condition:
        return

    # replace the condition with its simplified form and keep working on the result
    condition = condition.replace(simplify(condition))

    # everything that is not exclusively in DNF is treated as CNF-like
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    if cnf_like:
        connector, handler = exp.And, pushdown_cnf
    else:
        connector, handler = exp.Or, pushdown_dnf

    # only split the condition when its top-level connector matches the chosen form
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    handler(predicates, sources, scope_ref_count)
def
pushdown_cnf(predicates, scope, scope_ref_count):
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.
    """
    for predicate in predicates:
        targets = nodes_for_predicate(predicate, scope, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                # leave TRUE where the predicate was and move the predicate into the ON clause
                predicate.replace(exp.true())
                target.on(predicate, copy=False)
                break

            if not isinstance(target, exp.Select):
                continue

            predicate.replace(exp.true())
            rewritten = replace_aliases(target, predicate)
            # a predicate containing an aggregate goes to HAVING, anything else to WHERE
            attach = target.having if find_in_scope(rewritten, exp.AggFunc) else target.where
            attach(rewritten, copy=False)
If the predicates are in CNF-like form, each predicate can simply be replaced with TRUE in its parent and moved into the join or subquery it refers to.
def
pushdown_dnf(predicates, scope, scope_ref_count):
def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates: the blocks of the DNF condition.
        scope: mapping of source names, forwarded to ``nodes_for_predicate``.
        scope_ref_count: forwarded unchanged to ``nodes_for_predicate``.
    """
    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set(table_sets[0]).intersection(*table_sets[1:]) if table_sets else set()

    conditions = {}

    # for every pushdown table, find all related conditions in all predicates
    # combine them with ORs
    # (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
    for table in sorted(pushdown_tables):
        # stays empty only if there are no predicates, keeping the loop below safe
        nodes = {}

        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)

            if table not in nodes:
                continue

            predicate_condition = None

            for column in predicate.find_all(exp.Column):
                if column.table != table:
                    continue
                condition = column.find_ancestor(exp.Condition)
                predicate_condition = (
                    exp.and_(predicate_condition, condition) if predicate_condition else condition
                )

            if predicate_condition:
                conditions[table] = (
                    exp.or_(conditions[table], predicate_condition)
                    if table in conditions
                    else predicate_condition
                )

        # `nodes` holds the result for the last predicate examined above
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                inner_predicate = replace_aliases(node, predicate)
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)
If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.
def
nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by a predicate to the node the predicate may be pushed into.

    Args:
        predicate: the predicate whose column tables are inspected.
        sources: mapping of source name to a ``(node, source)`` pair.
        scope_ref_count: reference counts, looked up by ``id(source)``.
    Returns:
        dict: table name -> ``exp.Join`` or ``exp.Select`` node; empty when pushdown is ruled out.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # a join with a side other than RIGHT rules out pushdown for the whole predicate
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
            continue

        if not isinstance(node, exp.Select) or len(tables) != 1:
            continue

        if node.args.get("group"):
            continue

        # we can't push down predicates to select statements if they are referenced in
        # multiple places.
        if scope_ref_count[id(source)] >= 2:
            continue

        # We can't push down window expressions
        if any(projection for projection in node.selects if projection.find(exp.Window)):
            continue

        nodes[table] = node

    return nodes
def
replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Return a transformed predicate whose columns are swapped for the matching projections of source.

    Args:
        source: expression exposing ``selects``.
        predicate: expression whose columns are looked up by name among those projections.
    """
    projections = {}

    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        key = projection.alias if is_alias else projection.name
        # an aliased projection maps to the expression it wraps
        projections[key] = projection.this if is_alias else projection

    def _substitute(node):
        if not isinstance(node, exp.Column) or node.name not in projections:
            return node
        return projections[node.name].copy()

    return predicate.transform(_substitute)