sqlglot.optimizer.pushdown_predicates
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify


def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Example:
        >>> import sqlglot
        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    root = build_scope(expression)
    scope_ref_count = root.ref_count()

    # NOTE(review): scopes are visited in reverse traversal order; presumably this handles outer
    # scopes before the subqueries they push into -- confirm against Scope.traverse.
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        where = select.args.get("where")
        if where:
            selected_sources = scope.selected_sources
            # a right join can only push down to itself and not the source FROM table
            for k, (node, source) in selected_sources.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                    selected_sources = {k: (node, source)}
                    break
            pushdown(where.this, selected_sources, scope_ref_count)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        for join in select.args.get("joins") or []:
            name = join.this.alias_or_name
            pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)

    return expression


def pushdown(condition, sources, scope_ref_count):
    """
    Simplify ``condition`` in place and push its parts down into ``sources``.

    A CNF-like condition is split into its AND-ed conjuncts and handled by ``pushdown_cnf``.
    Otherwise it is split into its OR-ed disjuncts and handled by ``pushdown_dnf``.
    Does nothing when ``condition`` is falsy (e.g. a JOIN without ON).
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition))
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    predicates = list(
        condition.flatten()
        if isinstance(condition, exp.And if cnf_like else exp.Or)
        else [condition]
    )

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Each predicate is moved: its original position is replaced with TRUE and it is attached to a
    target returned by ``nodes_for_predicate``.

    * If JOIN targets exist, the predicate goes into the ON clause of the join that appears last
      in the SELECT's join list. A predicate such as ``y.c = z.c`` may only be attached once both
      ``y`` and ``z`` have been joined; the first join found could precede ``z``.
    * Otherwise, a SELECT target receives the predicate in its WHERE clause, with projection
      aliases expanded by ``replace_aliases``.
    """
    for predicate in predicates:
        nodes = list(nodes_for_predicate(predicate, scope, scope_ref_count).values())
        joins = [node for node in nodes if isinstance(node, exp.Join)]

        if joins:
            # latest join => every table the predicate references has been joined already;
            # ties (position unknown) keep the first candidate, as before
            target = max(joins, key=_join_position)
            predicate.replace(exp.true())
            target.on(predicate, copy=False)
            continue

        for node in nodes:
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                node.where(replace_aliases(node, predicate), copy=False)


def _join_position(join):
    """Return the index of ``join`` in its parent's ``joins`` list, or -1 if it isn't found."""
    parent = join.parent
    siblings = parent.args.get("joins") if parent is not None else None
    for index, sibling in enumerate(siblings or []):
        if sibling is join:
            return index
    return -1


def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    For every table referenced by all disjuncts, the steps are:

    1. In each disjunct, AND together the top-level conjuncts that mention the table.
    2. OR those per-disjunct results together:

           (a.x AND a.y AND b.x) OR (a.z AND c.y)  ->  (a.x AND a.y) OR (a.z)

    Each disjunct implies its part, so the result is implied by the original OR. It is therefore
    safe to add, but only when every disjunct contributes. If any disjunct cannot be pushed to
    the table's target, the table is skipped, because ORing only the remaining disjuncts would
    drop rows that the skipped disjunct keeps.

    The pushed condition is built from copies, and each table's condition is attached once.
    """
    if not predicates:
        return

    # tables referenced in every block of the DNF, e.g. (a.x AND b.x) OR (a.y AND c.y) -> {a}
    pushdown_tables = set(exp.column_table_names(predicates[0]))
    for predicate in predicates[1:]:
        pushdown_tables &= set(exp.column_table_names(predicate))

    for table in sorted(pushdown_tables):
        target = None
        combined = None

        for predicate in predicates:
            node = nodes_for_predicate(predicate, scope, scope_ref_count).get(table)
            block = _conditions_for_table(predicate, table) if node is not None else None

            if block is None:
                # this disjunct can't be represented at the target: pushing a partial OR
                # would be too restrictive, so give up on this table
                target = None
                break

            combined = exp.or_(combined, block) if combined is not None else block
            target = node

        if target is None:
            continue

        if isinstance(target, exp.Join):
            target.on(combined, copy=False)
        elif isinstance(target, exp.Select):
            target.where(replace_aliases(target, combined), copy=False)


def _conditions_for_table(predicate, table):
    """AND together copies of the top-level conjuncts of ``predicate`` that reference ``table``."""
    body = predicate.unnest()
    conjuncts = body.flatten() if isinstance(body, exp.And) else [body]
    block = None
    for conjunct in conjuncts:
        if any(column.table == table for column in conjunct.find_all(exp.Column)):
            # copy so the original predicate's nodes are not re-parented into another tree
            piece = conjunct.copy()
            block = exp.and_(block, piece) if block is not None else piece
    return block


def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by ``predicate`` to the node the predicate could be pushed into.

    Returns an empty dict when pushing is unsafe for the whole predicate: a referenced source is
    in a non-RIGHT sided join, or it comes from a WITH RECURSIVE clause.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in tables:
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes


def replace_aliases(source, predicate):
    """
    Return a transformed copy of ``predicate``. Each column whose name matches one of
    ``source``'s projections is replaced by a copy of that projection's underlying expression.
    """
    aliases = {}

    for select in source.selects:
        if isinstance(select, exp.Alias):
            aliases[select.alias] = select.this
        else:
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)
def
pushdown_predicates(expression):
def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Every scope's WHERE condition is pushed into its sources. Each JOIN's ON condition is pushed
    only into the source that the join itself introduces.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    root = build_scope(expression)
    scope_ref_count = root.ref_count()

    scopes = list(root.traverse())
    scopes.reverse()

    for scope in scopes:
        select = scope.expression
        where = select.args.get("where")

        if where:
            candidates = scope.selected_sources
            # a RIGHT-joined source may only receive the predicates itself, not the FROM table
            for name, entry in candidates.items():
                enclosing = entry[0].find_ancestor(exp.Join, exp.From)
                if isinstance(enclosing, exp.Join) and enclosing.side == "RIGHT":
                    candidates = {name: entry}
                    break
            pushdown(where.this, candidates, scope_ref_count)

        # an ON clause may only push into its own joined source, never into other joins
        for join in select.args.get("joins") or []:
            alias = join.this.alias_or_name
            own_source = {alias: scope.selected_sources[alias]}
            pushdown(join.args.get("on"), own_source, scope_ref_count)

    return expression
Rewrite the sqlglot AST to push down predicates into FROMs and JOINs.
Example:
>>> import sqlglot
>>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
pushdown(condition, sources, scope_ref_count):
def pushdown(condition, sources, scope_ref_count):
    """
    Simplify ``condition`` in place, then push its parts down into ``sources``.

    A CNF-like condition is split into its AND-ed conjuncts and handed to ``pushdown_cnf``.
    Anything else is split into its OR-ed disjuncts and handed to ``pushdown_dnf``. A falsy
    ``condition`` (e.g. a JOIN without ON) is ignored.
    """
    if not condition:
        return

    simplified = simplify(condition)
    condition = condition.replace(simplified)
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    handler = pushdown_cnf if cnf_like else pushdown_dnf
    handler(predicates, sources, scope_ref_count)
def
pushdown_cnf(predicates, scope, scope_ref_count):
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Each predicate is moved: its original position is replaced with TRUE and it is attached to a
    target returned by ``nodes_for_predicate``.

    * If JOIN targets exist, the predicate goes into the ON clause of the join that appears last
      in the SELECT's join list. A predicate such as ``y.c = z.c`` may only be attached once both
      ``y`` and ``z`` have been joined; the first join found could precede ``z`` and would then
      reference a table that is not yet in scope.
    * Otherwise, a SELECT target receives the predicate in its WHERE clause, with projection
      aliases expanded by ``replace_aliases``.

    Args:
        predicates: the AND-ed blocks of the condition being pushed down.
        scope: mapping of table name -> (node, source), forwarded to ``nodes_for_predicate``.
        scope_ref_count: per-scope reference counts, forwarded to ``nodes_for_predicate``.
    """
    for predicate in predicates:
        nodes = list(nodes_for_predicate(predicate, scope, scope_ref_count).values())
        joins = [node for node in nodes if isinstance(node, exp.Join)]

        if joins:
            # latest join => every table the predicate references has been joined already;
            # ties (position unknown) keep the first candidate, as before
            target = max(joins, key=_join_position)
            predicate.replace(exp.true())
            target.on(predicate, copy=False)
            continue

        for node in nodes:
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                node.where(replace_aliases(node, predicate), copy=False)


def _join_position(join):
    """Return the index of ``join`` in its parent's ``joins`` list, or -1 if it isn't found."""
    parent = join.parent
    siblings = parent.args.get("joins") if parent is not None else None
    for index, sibling in enumerate(siblings or []):
        if sibling is join:
            return index
    return -1
If the predicates are in CNF-like form, we can simply replace each block in the parent.
def
pushdown_dnf(predicates, scope, scope_ref_count):
def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    For every table referenced by all disjuncts, the steps are:

    1. In each disjunct, AND together the top-level conjuncts that mention the table.
    2. OR those per-disjunct results together:

           (a.x AND a.y AND b.x) OR (a.z AND c.y)  ->  (a.x AND a.y) OR (a.z)

    Each disjunct implies its part, so the result is implied by the original OR. It is therefore
    safe to add, but only when every disjunct contributes. If any disjunct cannot be pushed to
    the table's target (``nodes_for_predicate`` rejects it), the table is skipped, because ORing
    only the remaining disjuncts would drop rows that the skipped disjunct keeps.

    The pushed condition is built from copies, so the original predicates (which stay in place)
    are not re-parented. Each table's condition is attached exactly once.

    Args:
        predicates: the OR-ed blocks of the condition being pushed down.
        scope: mapping of table name -> (node, source), forwarded to ``nodes_for_predicate``.
        scope_ref_count: per-scope reference counts, forwarded to ``nodes_for_predicate``.
    """
    if not predicates:
        return

    # tables referenced in every block of the DNF, e.g. (a.x AND b.x) OR (a.y AND c.y) -> {a}
    pushdown_tables = set(exp.column_table_names(predicates[0]))
    for predicate in predicates[1:]:
        pushdown_tables &= set(exp.column_table_names(predicate))

    for table in sorted(pushdown_tables):
        target = None
        combined = None

        for predicate in predicates:
            node = nodes_for_predicate(predicate, scope, scope_ref_count).get(table)
            block = _conditions_for_table(predicate, table) if node is not None else None

            if block is None:
                # this disjunct can't be represented at the target: pushing a partial OR
                # would be too restrictive, so give up on this table
                target = None
                break

            combined = exp.or_(combined, block) if combined is not None else block
            target = node

        if target is None:
            continue

        if isinstance(target, exp.Join):
            target.on(combined, copy=False)
        elif isinstance(target, exp.Select):
            target.where(replace_aliases(target, combined), copy=False)


def _conditions_for_table(predicate, table):
    """AND together copies of the top-level conjuncts of ``predicate`` that reference ``table``."""
    body = predicate.unnest()
    conjuncts = body.flatten() if isinstance(body, exp.And) else [body]
    block = None
    for conjunct in conjuncts:
        if any(column.table == table for column in conjunct.find_all(exp.Column)):
            # copy so the original predicate's nodes are not re-parented into another tree
            piece = conjunct.copy()
            block = exp.and_(block, piece) if block is not None else piece
    return block
If the predicates are in DNF form, we can only push down conditions on tables that appear in all blocks. Additionally, we can't remove the predicates from their original position.
def
nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by ``predicate`` to the node the predicate could be pushed into.

    If the predicate sits under a WHERE, each source is first lifted to its enclosing JOIN or
    FROM. A FROM over a non-Table source (a derived table or CTE) is replaced by that source's
    inner expression.

    Only these targets are kept:

    * JOIN nodes without a side, or with side RIGHT.
    * SELECT nodes, provided the predicate references a single table and the SELECT has no
      GROUP BY, is referenced fewer than twice, and has no window function in its projections.

    An empty dict is returned as soon as a referenced source is in a join with another side, or
    comes from a WITH RECURSIVE clause.
    """
    tables = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    targets = {}

    for table in tables:
        entry = sources.get(table)
        node, source = entry if entry else (None, None)

        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_clause = source.parent.expression.args.get("with")
            if with_clause and with_clause.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        if not isinstance(node, exp.Select) or len(tables) != 1:
            continue

        # the reference count is only consulted when there is no GROUP BY
        if node.args.get("group") or scope_ref_count[id(source)] >= 2:
            continue

        # selects projecting window functions can't take pushed-down predicates
        if any(projection for projection in node.selects if projection.find(exp.Window)):
            continue

        targets[table] = node

    return targets
def
replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Return a transformed copy of ``predicate`` with projection aliases expanded.

    Every column whose name matches an output name of ``source`` is replaced by a copy of the
    projection's underlying expression. For an aliased projection that is the aliased
    expression; otherwise it is the projection itself. When names repeat, the later projection
    wins.
    """
    lookup = {}
    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        key = projection.alias if is_alias else projection.name
        lookup[key] = projection.this if is_alias else projection

    def swap(node):
        if not isinstance(node, exp.Column) or node.name not in lookup:
            return node
        return lookup[node.name].copy()

    return predicate.transform(swap)