sqlglot.optimizer.pushdown_predicates
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify


def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    root = build_scope(expression)

    if root:
        scope_ref_count = root.ref_count()

        # scopes are visited in the reverse of root.traverse() order
        for scope in reversed(list(root.traverse())):
            select = scope.expression
            where = select.args.get("where")
            if where:
                selected_sources = scope.selected_sources
                # a right join can only push down to itself and not the source FROM table
                for k, (node, source) in selected_sources.items():
                    parent = node.find_ancestor(exp.Join, exp.From)
                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                        selected_sources = {k: (node, source)}
                        break
                pushdown(where.this, selected_sources, scope_ref_count)

            # joins should only pushdown into itself, not to other joins
            # so we limit the selected sources to only itself
            for join in select.args.get("joins") or []:
                name = join.alias_or_name
                if name in scope.selected_sources:
                    pushdown(
                        join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
                    )

    return expression


def pushdown(condition, sources, scope_ref_count):
    """
    Simplify *condition* in place and push its parts down into *sources*.

    The condition is treated as CNF-like when it is normalized (CNF) or is not in DNF;
    its AND-ed (CNF-like) or OR-ed (DNF) parts are then handed to pushdown_cnf or
    pushdown_dnf respectively. A falsy condition is ignored.
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition))
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    predicates = list(
        condition.flatten()
        if isinstance(condition, exp.And if cnf_like else exp.Or)
        else [condition]
    )

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Each conjunct is replaced by TRUE at its original position and attached either to the
    first Join target found for it or, for a Select target, to that Select's WHERE with
    projection aliases resolved.
    """
    for predicate in predicates:
        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
            if isinstance(node, exp.Join):
                predicate.replace(exp.true())
                node.on(predicate, copy=False)
                # a conjunct is attached to at most one join
                break
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                node.where(replace_aliases(node, predicate), copy=False)


def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    For each table referenced by every block, the conjuncts of each block that reference
    that table are AND-ed together and the per-block results are OR-ed, e.g.
    (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z).
    The derived condition is attached (as a copy) to the table's target node only when
    that same node was found for the table in every block: OR-ing only a subset of the
    blocks would filter out rows that satisfy one of the missing blocks.

    Args:
        predicates: the OR-ed blocks of the condition.
        scope: mapping passed through to nodes_for_predicate to locate target nodes.
        scope_ref_count: reference counts passed through to nodes_for_predicate.
    """
    if not predicates:
        return

    # only tables referenced in every block can be pushed down
    # (a.x AND b.x) OR (a.y AND c.y) -> only table a
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set(table_sets[0]).intersection(*table_sets[1:])

    # target nodes per block; pushing a condition does not change how targets are found
    nodes_per_block = [
        nodes_for_predicate(predicate, scope, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        target = None
        block_conditions = []

        for predicate, nodes in zip(predicates, nodes_per_block):
            node = nodes.get(table)

            # every block must contribute for the same node, otherwise the OR is incomplete
            if node is None or (target is not None and node is not target):
                block_conditions = []
                break
            target = node

            # use whole conjuncts, not the nearest Condition ancestor of a column, which
            # could be a non-boolean sub-expression such as f(a.x) or a.x + 1
            conjuncts = predicate.flatten() if isinstance(predicate, exp.And) else [predicate]
            related = [
                conjunct.copy()
                for conjunct in conjuncts
                if table in exp.column_table_names(conjunct)
            ]

            if not related:
                # this block puts no constraint on the table, so the OR cannot be bounded
                block_conditions = []
                break

            block_conditions.append(exp.and_(*related))

        if not block_conditions:
            continue

        condition = exp.or_(*block_conditions)

        if isinstance(target, exp.Join):
            target.on(condition, copy=False)
        elif isinstance(target, exp.Select):
            target.where(replace_aliases(target, condition), copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by *predicate* to the node it can be pushed into.

    A Join target is kept unless the join has a side other than RIGHT, in which case an
    empty dict is returned. A Select target is kept only when the predicate references a
    single table, the Select has no GROUP BY and no window function in its projections,
    and scope_ref_count[id(source)] < 2. An empty dict is also returned when a non-Table
    FROM source's parent query has a recursive WITH.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes


def replace_aliases(source, predicate):
    """
    Resolve *source*'s projection aliases inside *predicate*.

    Every exp.Column in *predicate* whose name matches a projection output name of
    *source* (the alias of an exp.Alias, otherwise the projection's name) is replaced by
    a copy of the projected expression. Matching is by column name only; the column's
    table qualifier is not checked. Returns the result of predicate.transform(...).
    """
    aliases = {}

    for select in source.selects:
        if isinstance(select, exp.Alias):
            aliases[select.alias] = select.this
        else:
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)
def
pushdown_predicates(expression):
def pushdown_predicates(expression):
    """
    Push WHERE and JOIN ... ON predicates down into FROM sources and joins.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): the tree to rewrite; it is modified in place.
    Returns:
        sqlglot.Expression: the same (rewritten) expression.
    """
    root = build_scope(expression)

    if not root:
        return expression

    ref_counts = root.ref_count()

    def where_sources(scope):
        # When a source sits under a RIGHT JOIN, WHERE predicates may only go to that
        # source, never to the FROM table.
        candidates = scope.selected_sources
        for source_name, (source_node, source) in candidates.items():
            enclosing = source_node.find_ancestor(exp.Join, exp.From)
            if isinstance(enclosing, exp.Join) and enclosing.side == "RIGHT":
                return {source_name: (source_node, source)}
        return candidates

    for scope in reversed(list(root.traverse())):
        query = scope.expression
        where_clause = query.args.get("where")

        if where_clause:
            pushdown(where_clause.this, where_sources(scope), ref_counts)

        # a join's ON condition may only be pushed into that join's own source
        for join in query.args.get("joins") or []:
            join_name = join.alias_or_name
            if join_name not in scope.selected_sources:
                continue
            pushdown(
                join.args.get("on"),
                {join_name: scope.selected_sources[join_name]},
                ref_counts,
            )

    return expression
Rewrite a sqlglot AST to push down predicates into FROMs and JOINs.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
pushdown(condition, sources, scope_ref_count):
def pushdown(condition, sources, scope_ref_count):
    """
    Simplify *condition* in place, then push its parts into *sources*.

    A condition that is normalized (CNF), or that is not in DNF, is handled as CNF-like:
    its AND-ed parts go to pushdown_cnf. Otherwise its OR-ed parts go to pushdown_dnf.
    A falsy condition is ignored.
    """
    if not condition:
        return

    simplified = simplify(condition)
    condition = condition.replace(simplified)

    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
    connector = exp.And if cnf_like else exp.Or

    if isinstance(condition, connector):
        parts = list(condition.flatten())
    else:
        parts = [condition]

    strategy = pushdown_cnf if cnf_like else pushdown_dnf
    strategy(parts, sources, scope_ref_count)
def
pushdown_cnf(predicates, scope, scope_ref_count):
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Each conjunct is handled independently. It is replaced by TRUE at its original
    position, then either attached to the ON of the first Join target returned by
    nodes_for_predicate, or added to a Select target's WHERE after projection aliases
    are resolved with replace_aliases.
    """
    for predicate in predicates:
        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
            if isinstance(node, exp.Join):
                # replace first so the predicate is detached before being attached to ON
                predicate.replace(exp.true())
                node.on(predicate, copy=False)
                # a conjunct is attached to at most one join
                break
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                node.where(replace_aliases(node, predicate), copy=False)
If the predicates are in a CNF-like form, each block can be pushed down on its own: it is replaced by TRUE in its parent and attached to the target join or subquery.
def
pushdown_dnf(predicates, scope, scope_ref_count):
def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    For each table referenced by every block, the conjuncts of each block that reference
    that table are AND-ed together and the per-block results are OR-ed, e.g.
    (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z).
    The derived condition is attached (as a copy) to the table's target node only when
    that same node was found for the table in every block: OR-ing only a subset of the
    blocks would filter out rows that satisfy one of the missing blocks.

    Args:
        predicates: the OR-ed blocks of the condition.
        scope: mapping passed through to nodes_for_predicate to locate target nodes.
        scope_ref_count: reference counts passed through to nodes_for_predicate.
    """
    if not predicates:
        return

    # only tables referenced in every block can be pushed down
    # (a.x AND b.x) OR (a.y AND c.y) -> only table a
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set(table_sets[0]).intersection(*table_sets[1:])

    # target nodes per block; pushing a condition does not change how targets are found
    nodes_per_block = [
        nodes_for_predicate(predicate, scope, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        target = None
        block_conditions = []

        for predicate, nodes in zip(predicates, nodes_per_block):
            node = nodes.get(table)

            # every block must contribute for the same node, otherwise the OR is incomplete
            if node is None or (target is not None and node is not target):
                block_conditions = []
                break
            target = node

            # use whole conjuncts, not the nearest Condition ancestor of a column, which
            # could be a non-boolean sub-expression such as f(a.x) or a.x + 1
            conjuncts = predicate.flatten() if isinstance(predicate, exp.And) else [predicate]
            related = [
                conjunct.copy()
                for conjunct in conjuncts
                if table in exp.column_table_names(conjunct)
            ]

            if not related:
                # this block puts no constraint on the table, so the OR cannot be bounded
                block_conditions = []
                break

            block_conditions.append(exp.and_(*related))

        if not block_conditions:
            continue

        condition = exp.or_(*block_conditions)

        if isinstance(target, exp.Join):
            target.on(condition, copy=False)
        elif isinstance(target, exp.Select):
            target.where(replace_aliases(target, condition), copy=False)
If the predicates are in DNF form, we can only push down conditions on tables that are referenced in every block. Additionally, the original predicates cannot be removed from their original location; the pushed-down conditions are added alongside them.
def
nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by *predicate* to the node it can be pushed into.

    For a predicate whose nearest Join/Where ancestor is a Where, each source node is
    walked up to its enclosing Join or From. A From whose source is not an exp.Table is
    resolved to source.expression. A Join target is kept unless the join has a side other
    than RIGHT, in which case an empty dict is returned. A Select target is kept only when
    the predicate references a single table, the Select has no GROUP BY and no window
    function in its projections, and scope_ref_count[id(source)] < 2. An empty dict is
    also returned when a non-Table FROM source's parent query has a recursive WITH.

    Args:
        predicate: the condition to push down.
        sources: mapping of table name -> (node, source).
        scope_ref_count: mapping indexed by id(source).
    Returns:
        dict of table name -> exp.Join or exp.Select.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                # nothing may be pushed when the enclosing WITH is recursive
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                # a sided join other than RIGHT (e.g. LEFT) blocks the whole predicate
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes
def
replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Resolve *source*'s projection aliases inside *predicate*.

    Every exp.Column whose name matches a projection output name of *source* (the alias
    of an exp.Alias, otherwise the projection's own name) is replaced by a copy of the
    projected expression. Matching is by column name only; table qualifiers are ignored.
    Returns the result of predicate.transform(...).
    """
    projections = {}
    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        output_name = projection.alias if is_alias else projection.name
        projections[output_name] = projection.this if is_alias else projection

    def _replace_alias(node):
        if not isinstance(node, exp.Column):
            return node
        if node.name not in projections:
            return node
        return projections[node.name].copy()

    return predicate.transform(_replace_alias)