sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Scalar subqueries are converted into cross joins.
    Correlated or vectorized subqueries are converted into a grouped subquery that is
    left joined, so that the join is not many to many.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # one alias generator is shared by every rewrite so that join aliases do not collide
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        if not parent:
            # nothing encloses this select, so there is nothing to join it into
            continue
        if scope.external_columns:
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Turn a subquery without external columns into a join on its parent select.

    A subquery used as a scalar becomes a CROSS JOIN. A subquery used with IN / ANY becomes
    a LEFT JOIN on a grouped version of the subquery, and the predicate is replaced by a
    NOT ... IS NULL check on the joined column.

    Nothing is rewritten when the subquery projects more than one expression, when it has no
    enclosing condition belonging to `parent_select`, or (for IN / ANY) when it contains
    LIMIT / OFFSET.

    Args:
        select: the subquery's select expression (mutated in place).
        parent_select: the enclosing select that receives the join (mutated in place).
        next_alias_name: zero-argument callable returning a fresh alias on each call.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE(review): the alias is drawn before the guard below, so a name is consumed even when
    # we bail out. Kept as is because moving it would change the numbering of generated aliases.
    alias = next_alias_name()

    if not predicate or parent_select is not predicate.parent_select:
        return

    # this subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        having = predicate.find_ancestor(exp.Having)
        column = exp.column(select.selects[0].alias_or_name, alias)
        # inside the parent's HAVING the joined column is referenced through MAX
        if having and having.parent_select is parent_select:
            column = exp.Max(this=column)
        _replace(select.parent, column)

        parent_select.join(
            select,
            join_type="CROSS",
            join_alias=alias,
            copy=False,
        )
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # for `x = ANY (...)` the predicate to rewrite is the enclosing equality
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped subquery that is LEFT JOINed to its parent.

    The rewrite is abandoned, before anything is mutated, when:
        - the subquery has no WHERE, its WHERE contains an OR, or it contains LIMIT / OFFSET
        - an external column is not located under the subquery's own WHERE
        - an external column's nearest predicate is missing, is not under that WHERE,
          or is not a binary expression
        - none of the collected predicates is an equality

    Args:
        select: the correlated subquery's select expression (mutated in place).
        parent_select: the enclosing select that receives the join (mutated in place).
        external_columns: the scope's external columns, as passed by `unnest_subqueries`.
        next_alias_name: zero-argument callable returning a fresh alias on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE(review): drawn before the validation loop, so an alias is consumed even if we bail out
    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand on the opposite side from the external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all keys of equality predicates must also be part of the grouping
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            # NOTE(review): only the ANY node is replaced here, not its parent comparison
            # (unlike the ALL branch above) - confirm this is intended
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need to take that into consideration
        # here by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlation predicate moves out of the subquery's WHERE
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """Replace `expression` with `exp.condition(condition)` and return the result of `replace`."""
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand that a subquery predicate is compared against.

    For IN this is `expression.this`. For ANY / ALL the lookup is delegated to the parent
    expression. For a binary expression it is the side that is not a subquery-like node.
    Returns None for anything else.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def
unnest_subqueries(expression):
def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Scalar subqueries are converted into cross joins.
    Correlated or vectorized subqueries are converted into a grouped subquery that is
    left joined, so that the join is not many to many.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # one alias generator is shared by every rewrite so that join aliases do not collide
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        if not parent:
            # nothing encloses this select, so there is nothing to join it into
            continue
        if scope.external_columns:
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Scalar subqueries are converted into cross joins. Correlated or vectorized subqueries are converted into a grouped subquery that is left joined, so that the join is not many to many.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def
unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Turn a subquery without external columns into a join on its parent select.

    A subquery used as a scalar becomes a CROSS JOIN. A subquery used with IN / ANY becomes
    a LEFT JOIN on a grouped version of the subquery, and the predicate is replaced by a
    NOT ... IS NULL check on the joined column.

    Nothing is rewritten when the subquery projects more than one expression, when it has no
    enclosing condition belonging to `parent_select`, or (for IN / ANY) when it contains
    LIMIT / OFFSET.

    Args:
        select: the subquery's select expression (mutated in place).
        parent_select: the enclosing select that receives the join (mutated in place).
        next_alias_name: zero-argument callable returning a fresh alias on each call.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE(review): the alias is drawn before the guard below, so a name is consumed even when
    # we bail out. Kept as is because moving it would change the numbering of generated aliases.
    alias = next_alias_name()

    if not predicate or parent_select is not predicate.parent_select:
        return

    # this subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        having = predicate.find_ancestor(exp.Having)
        column = exp.column(select.selects[0].alias_or_name, alias)
        # inside the parent's HAVING the joined column is referenced through MAX
        if having and having.parent_select is parent_select:
            column = exp.Max(this=column)
        _replace(select.parent, column)

        parent_select.join(
            select,
            join_type="CROSS",
            join_alias=alias,
            copy=False,
        )
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # for `x = ANY (...)` the predicate to rewrite is the enclosing equality
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def
decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped subquery that is LEFT JOINed to its parent.

    The rewrite is abandoned, before anything is mutated, when:
        - the subquery has no WHERE, its WHERE contains an OR, or it contains LIMIT / OFFSET
        - an external column is not located under the subquery's own WHERE
        - an external column's nearest predicate is missing, is not under that WHERE,
          or is not a binary expression
        - none of the collected predicates is an equality

    Args:
        select: the correlated subquery's select expression (mutated in place).
        parent_select: the enclosing select that receives the join (mutated in place).
        external_columns: the scope's external columns, as passed by `unnest_subqueries`.
        next_alias_name: zero-argument callable returning a fresh alias on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE(review): drawn before the validation loop, so an alias is consumed even if we bail out
    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand on the opposite side from the external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all keys of equality predicates must also be part of the grouping
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            # NOTE(review): only the ANY node is replaced here, not its parent comparison
            # (unlike the ALL branch above) - confirm this is intended
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need to take that into consideration
        # here by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlation predicate moves out of the subquery's WHERE
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )