sqlglot.optimizer.pushdown_predicates
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify


def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to push down predicates into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    root = build_scope(expression)

    if root:
        scope_ref_count = root.ref_count()

        for scope in reversed(list(root.traverse())):
            select = scope.expression
            where = select.args.get("where")
            if where:
                selected_sources = scope.selected_sources
                # a right join can only push down to itself and not the source FROM table
                for k, (node, source) in selected_sources.items():
                    parent = node.find_ancestor(exp.Join, exp.From)
                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                        selected_sources = {k: (node, source)}
                        break
                pushdown(where.this, selected_sources, scope_ref_count)

            # joins should only pushdown into itself, not to other joins
            # so we limit the selected sources to only itself
            for join in select.args.get("joins") or []:
                name = join.this.alias_or_name
                # A join whose target is not a selected source has nothing to push
                # into; skip it rather than failing with a KeyError.
                if name not in scope.selected_sources:
                    continue
                pushdown(
                    join.args.get("on"),
                    {name: scope.selected_sources[name]},
                    scope_ref_count,
                )

    return expression


def pushdown(condition, sources, scope_ref_count):
    """
    Simplify `condition` (replacing it in its tree) and push its parts into `sources`.

    CNF-like conditions are split into their AND-ed conjuncts and handled by
    `pushdown_cnf`; DNF conditions are split into their OR-ed blocks and handled by
    `pushdown_dnf`. A missing condition is a no-op.
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition))
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    predicates = list(
        condition.flatten()
        if isinstance(condition, exp.And if cnf_like else exp.Or)
        else [condition]
    )

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF-like form, we can simply replace each block in its parent.

    Each predicate is replaced by TRUE where it stands and moved into the first eligible
    JOIN's ON clause, or into the target SELECT's WHERE (rewritten via `replace_aliases`).
    """
    for predicate in predicates:
        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
            if isinstance(node, exp.Join):
                predicate.replace(exp.true())
                node.on(predicate, copy=False)
                break
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                node.where(replace_aliases(node, predicate), copy=False)


def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions on tables that are
    referenced in all blocks. Additionally, we can't remove predicates from their original form.

    For each such table, the conjuncts of every block that mention the table are AND-ed and
    the per-block results are OR-ed, e.g.
    (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z).
    That disjunction is implied by the original condition, so a copy of it is added to the
    table's JOIN ... ON or derived-table WHERE while the original condition stays in place.
    """
    predicates = list(predicates)
    if not predicates:
        return

    # Tables referenced in all blocks: in (a.x AND b.x) OR (a.y AND c.y) only `a` qualifies.
    pushdown_tables = set(exp.column_table_names(predicates[0]))
    for predicate in predicates[1:]:
        pushdown_tables &= set(exp.column_table_names(predicate))

    for table in sorted(pushdown_tables):
        target = None
        condition = None

        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
            block_condition = None

            if table in nodes:
                conjuncts = predicate.flatten() if isinstance(predicate, exp.And) else [predicate]
                for conjunct in conjuncts:
                    # Take whole conjuncts rather than the column's innermost enclosing
                    # node, so NOT / function calls / arithmetic around it are preserved.
                    if any(column.table == table for column in conjunct.find_all(exp.Column)):
                        conjunct = conjunct.copy()
                        block_condition = (
                            conjunct
                            if block_condition is None
                            else exp.and_(block_condition, conjunct)
                        )

            # Every block must contribute a term: OR-ing only some of them would filter out
            # rows that satisfy a skipped block, so this table cannot be pushed down to.
            if block_condition is None:
                condition = None
                break

            target = nodes[table]
            condition = (
                block_condition if condition is None else exp.or_(condition, block_condition)
            )

        if condition is None:
            continue

        if isinstance(target, exp.Join):
            target.on(condition, copy=False)
        elif isinstance(target, exp.Select):
            target.where(replace_aliases(target, condition), copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the JOIN or SELECT node it may be pushed into.

    Returns an empty dict when a referenced table sits behind a LEFT/FULL/other non-RIGHT
    sided join, or when its FROM source belongs to a recursive WITH. A SELECT target is only
    used when the predicate references exactly one table, the SELECT has no GROUP BY and no
    window-function projections, and the source scope is referenced fewer than twice.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in tables:
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # NOTE(review): a WHERE predicate moved into a RIGHT JOIN's ON no longer filters
            # rows of the preserved side, which is not equivalent in general -- confirm.
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes


def replace_aliases(source, predicate):
    """
    Rewrite columns of `predicate` in terms of the projections of the SELECT `source`.

    A column whose name matches a projection is replaced by a copy of the aliased expression
    (for `expr AS name`) or of the projection itself; the column's qualifier is ignored.
    """
    aliases = {}

    for select in source.selects:
        if isinstance(select, exp.Alias):
            aliases[select.alias] = select.this
        else:
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)
def
pushdown_predicates(expression):
def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to push down predicates into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    root = build_scope(expression)

    if root:
        scope_ref_count = root.ref_count()

        for scope in reversed(list(root.traverse())):
            select = scope.expression
            where = select.args.get("where")
            if where:
                selected_sources = scope.selected_sources
                # a right join can only push down to itself and not the source FROM table
                for k, (node, source) in selected_sources.items():
                    parent = node.find_ancestor(exp.Join, exp.From)
                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                        selected_sources = {k: (node, source)}
                        break
                pushdown(where.this, selected_sources, scope_ref_count)

            # joins should only pushdown into itself, not to other joins
            # so we limit the selected sources to only itself
            for join in select.args.get("joins") or []:
                name = join.this.alias_or_name
                # A join whose target is not a selected source has nothing to push
                # into; skip it rather than failing with a KeyError.
                if name not in scope.selected_sources:
                    continue
                pushdown(
                    join.args.get("on"),
                    {name: scope.selected_sources[name]},
                    scope_ref_count,
                )

    return expression
Rewrite the sqlglot AST to push down predicates into FROMs and JOINs.
Example:
>>> import sqlglot
>>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
pushdown(condition, sources, scope_ref_count):
def pushdown(condition, sources, scope_ref_count):
    """
    Simplify `condition` (replacing it in its tree) and push its parts into `sources`.

    A CNF-like condition is split into its AND-ed conjuncts and passed to `pushdown_cnf`;
    otherwise it is split into its OR-ed blocks and passed to `pushdown_dnf`.
    A missing/falsy condition is a no-op.
    """
    if condition:
        simplified = condition.replace(simplify(condition))
        is_cnf = normalized(simplified) or not normalized(simplified, dnf=True)

        connector = exp.And if is_cnf else exp.Or
        if isinstance(simplified, connector):
            parts = list(simplified.flatten())
        else:
            parts = [simplified]

        handler = pushdown_cnf if is_cnf else pushdown_dnf
        handler(parts, sources, scope_ref_count)
def
pushdown_cnf(predicates, scope, scope_ref_count):
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF-like form, we can simply replace each block in its parent.

    Each predicate is swapped for TRUE where it currently stands and then attached either to
    the ON clause of the first eligible JOIN, or to the WHERE of the target SELECT after its
    column references are rewritten with `replace_aliases`.
    """
    for predicate in predicates:
        targets = nodes_for_predicate(predicate, scope, scope_ref_count)
        for target in targets.values():
            if isinstance(target, exp.Select):
                predicate.replace(exp.true())
                target.where(replace_aliases(target, predicate), copy=False)
            elif isinstance(target, exp.Join):
                predicate.replace(exp.true())
                target.on(predicate, copy=False)
                # A predicate is attached to at most one JOIN.
                break
If the predicates are in CNF-like form, we can simply replace each block in its parent and move it down.
def
pushdown_dnf(predicates, scope, scope_ref_count):
def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions on tables that are
    referenced in all blocks. Additionally, we can't remove predicates from their original form.

    For each such table, the conjuncts of every block that mention the table are AND-ed and
    the per-block results are OR-ed, e.g.
    (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z).
    That disjunction is implied by the original condition, so a copy of it is added to the
    table's JOIN ... ON or derived-table WHERE while the original condition stays in place.

    Args:
        predicates: the OR-ed blocks of the condition.
        scope: mapping of source name -> (node, source), passed to `nodes_for_predicate`.
        scope_ref_count: per-scope reference counts, passed to `nodes_for_predicate`.
    """
    predicates = list(predicates)
    if not predicates:
        return

    # Tables referenced in all blocks: in (a.x AND b.x) OR (a.y AND c.y) only `a` qualifies.
    pushdown_tables = set(exp.column_table_names(predicates[0]))
    for predicate in predicates[1:]:
        pushdown_tables &= set(exp.column_table_names(predicate))

    for table in sorted(pushdown_tables):
        target = None
        condition = None

        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
            block_condition = None

            if table in nodes:
                conjuncts = predicate.flatten() if isinstance(predicate, exp.And) else [predicate]
                for conjunct in conjuncts:
                    # Take whole conjuncts rather than the column's innermost enclosing
                    # node, so NOT / function calls / arithmetic around it are preserved.
                    if any(column.table == table for column in conjunct.find_all(exp.Column)):
                        conjunct = conjunct.copy()
                        block_condition = (
                            conjunct
                            if block_condition is None
                            else exp.and_(block_condition, conjunct)
                        )

            # Every block must contribute a term: OR-ing only some of them would filter out
            # rows that satisfy a skipped block, so this table cannot be pushed down to.
            if block_condition is None:
                condition = None
                break

            target = nodes[table]
            condition = (
                block_condition if condition is None else exp.or_(condition, block_condition)
            )

        if condition is None:
            continue

        # Each table's implied condition is pushed exactly once, into its own node.
        if isinstance(target, exp.Join):
            target.on(condition, copy=False)
        elif isinstance(target, exp.Select):
            target.where(replace_aliases(target, condition), copy=False)
If the predicates are in DNF form, we can only push down conditions on tables that are referenced in all blocks. Additionally, we can't remove predicates from their original position.
def
nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the JOIN or SELECT node it may be pushed into.

    An empty dict is returned as soon as a referenced table sits behind a sided join other
    than RIGHT, or its FROM source belongs to a recursive WITH. A SELECT is only used as a
    target when the predicate references exactly one table, the SELECT has no GROUP BY and
    no window-function projections, and its source scope is referenced fewer than twice.
    """
    referenced = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    targets = {}

    for name in referenced:
        target, source = sources.get(name) or (None, None)

        # A WHERE predicate may travel to the enclosing FROM or JOIN of its source.
        if target and in_where:
            target = target.find_ancestor(exp.Join, exp.From)

        # A FROM over a non-table source (e.g. a CTE) pushes into that source's SELECT,
        # unless the WITH is recursive, in which case nothing is pushed at all.
        if isinstance(target, exp.From) and not isinstance(source, exp.Table):
            with_clause = source.parent.expression.args.get("with")
            if with_clause and with_clause.recursive:
                return {}
            target = source.expression

        if isinstance(target, exp.Join):
            # NOTE(review): a WHERE predicate moved into a RIGHT JOIN's ON no longer filters
            # rows of the preserved side, which is not equivalent in general -- confirm.
            if target.side and target.side != "RIGHT":
                return {}
            targets[name] = target
            continue

        if not isinstance(target, exp.Select) or len(referenced) != 1:
            continue

        # Window functions and GROUP BY would change meaning if filtered earlier; a scope
        # referenced from several places must not be filtered for just one of them.
        windowed = any(
            projection for projection in target.selects if projection.find(exp.Window)
        )
        if (
            not target.args.get("group")
            and scope_ref_count[id(source)] < 2
            and not windowed
        ):
            targets[name] = target

    return targets
def
replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Rewrite columns of `predicate` in terms of the projections of the SELECT `source`.

    A column whose name matches a projection is replaced by a copy of the aliased expression
    (for `expr AS name`) or of the projection itself; the column's qualifier is ignored.
    Later projections with the same name win. Returns `predicate.transform(...)`'s result.
    """
    projections = {}
    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        key = projection.alias if is_alias else projection.name
        projections[key] = projection.this if is_alias else projection

    def _substitute(node):
        if isinstance(node, exp.Column) and node.name in projections:
            node = projections[node.name].copy()
        return node

    return predicate.transform(_substitute)