sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into a GROUP BY so that the resulting
    LEFT JOIN is not many-to-many.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest; modified in place
    Returns:
        sqlglot.Expression: the same, unnested expression
    """
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # Only subqueries nested inside another SELECT can become joins.
        if not parent:
            continue
        if scope.external_columns:
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated subquery into a join on ``parent_select``.

    A subquery that is not the right-hand side of ``IN`` / ``ANY`` is treated as a
    scalar. It is CROSS JOINed and replaced by a reference to its single projection,
    which is wrapped in ``MAX`` when the reference lands in an aggregating context of
    the parent. An ``IN`` / ``= ANY`` subquery is LEFT JOINed on the compared operand,
    and the predicate is replaced with ``NOT <alias>.<value> IS NULL``.

    The tree is left unchanged when any of these holds:
    - the subquery projects more than one column;
    - it is not inside a condition of ``parent_select``;
    - for ``IN`` / ``ANY``, it has a LIMIT/OFFSET;
    - for ``ANY``, the ``ANY`` is not inside an equality of ``parent_select``.

    Args:
        select (sqlglot.exp.Select): the subquery; modified in place.
        parent_select (sqlglot.exp.Select): the enclosing SELECT that receives the join.
        next_alias_name (Callable[[], str]): returns a fresh table alias on each call.
    """
    # Only single-column subqueries can be mapped onto one join value.
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)

    if not predicate or parent_select is not predicate.parent_select:
        return

    # Draw the alias only once the subquery is known to belong to parent_select, so
    # skipped subqueries do not leave gaps in the generated alias sequence.
    alias = next_alias_name()

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        # In an aggregating context a bare column reference is not allowed, so wrap it
        # in MAX.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For `x = ANY (subquery)` the condition to rewrite is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    # NOTE(review): assumes the projection is aliased (e.g. by qualify_columns), since
    # value.alias is used as the join column name — confirm.
    value = select.selects[0]

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    # After the LEFT JOIN, a match is signalled by a non-NULL join key.
    _replace(predicate, f"NOT {on.right} IS NULL")

    # Each value may appear only once in the joined subquery; otherwise the LEFT JOIN
    # would duplicate rows of parent_select.
    group = select.args.get("group")

    if group:
        if {value.this} != set(group.expressions):
            # Appending to an existing GROUP BY would either keep duplicate values
            # (other keys are still grouped) or group by an aggregate (invalid SQL),
            # so deduplicate the projected value in a wrapping SELECT instead.
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q"))
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    elif not value.find(exp.AggFunc):
        select = select.group_by(value.this, copy=False)
    # Otherwise the value is an aggregate without GROUP BY: it already yields a single
    # row, and grouping by an aggregate expression would be invalid SQL.

    parent_select.join(
        select,
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite the correlated subquery ``select`` as a LEFT JOIN of ``parent_select``.

    Every external column must sit in a binary predicate of the subquery's own WHERE
    clause. The operand on the other side of that predicate (the "key") is what the
    subquery projects for the join.
    - Equality keys are projected and grouped on, and the equality predicates become
      the join condition.
    - Keys of other comparisons are aggregated: ``MAX`` for subquery projections,
      ``ARRAY_AGG`` otherwise. Outside subquery projections they are re-checked in the
      parent with ``ARRAY_ANY``.

    The parent predicate that held the subquery is rewritten to reference the joined
    derived table. That predicate is ``EXISTS``, ``ALL``, ``ANY``, ``IN``, any other
    comparison, or a projection.

    The tree is left unchanged when any of these holds:
    - the WHERE clause is missing or contains ``OR``;
    - the subquery has ``LIMIT``/``OFFSET``;
    - an external column is not inside a binary predicate of that WHERE clause;
    - none of those predicates is an equality.

    Args:
        select (sqlglot.exp.Select): the correlated subquery; modified in place.
        parent_select (sqlglot.exp.Select): the enclosing SELECT that receives the join.
        external_columns: columns of ``select`` that reference an outer scope.
        next_alias_name (Callable[[], str]): returns a fresh alias on each call.
    """
    where = select.args.get("where")

    # Only OR-free filters can be split into join keys, and a LIMIT/OFFSET would no
    # longer apply per outer row once the subquery is joined.
    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE: drawn before the key checks below, so an alias is consumed even when one
    # of those checks bails out.
    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand that does not contain the external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node itself is one of parent_select's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the GROUP BY
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # non-equality keys are aggregated so they survive the GROUP BY
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need to take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlated predicate leaves the subquery's WHERE; equalities are reused
        # (with the key swapped for the joined column) as the join condition below
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            # NOTE(review): a non-equality predicate is not added to the join condition
            # or the parent here, so it appears to be dropped for subquery
            # projections — confirm.
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            # NOTE(review): appears unreachable — every EQ key was added to group_by above.
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """
    Replace ``expression`` in its tree with ``condition`` parsed by ``exp.condition``.

    Returns whatever ``Expression.replace`` returns; callers in this module use it as
    the newly inserted node.
    """
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand that a subquery is compared against in ``expression``.

    - For ``IN``, it is ``expression.this``.
    - For ``ANY`` / ``ALL``, the lookup is delegated to the enclosing node.
    - For binary nodes, it is whichever side is not a Subquery / ANY / EXISTS / ALL.
    - For anything else (including ``None``), it returns ``None``.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def
unnest_subqueries(expression):
def unnest_subqueries(expression):
    """
    Rewrite a sqlglot AST so that some subquery predicates become joins.

    Scalar subqueries are turned into cross joins. Correlated and vectorized
    subqueries are grouped by their join keys, so the LEFT JOIN that replaces them
    is never many-to-many.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest; it is modified in place.
    Returns:
        sqlglot.Expression: the same ``expression`` object, unnested.
    """
    alias_factory = name_sequence("_u_")

    for current_scope in traverse_scope(expression):
        inner_select = current_scope.expression
        outer_select = inner_select.parent_select

        # Only subqueries nested inside another SELECT can be turned into joins.
        if not outer_select:
            continue

        correlated_columns = current_scope.external_columns
        if correlated_columns:
            # The subquery references the outer query: rewrite it as a grouped LEFT JOIN.
            decorrelate(inner_select, outer_select, correlated_columns, alias_factory)
        elif current_scope.scope_type == ScopeType.SUBQUERY:
            unnest(inner_select, outer_select, alias_factory)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into a GROUP BY so that the resulting LEFT JOIN is not many-to-many.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def
unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated subquery into a join on ``parent_select``.

    A subquery that is not the right-hand side of ``IN`` / ``ANY`` is treated as a
    scalar. It is CROSS JOINed and replaced by a reference to its single projection,
    which is wrapped in ``MAX`` when the reference lands in an aggregating context of
    the parent. An ``IN`` / ``= ANY`` subquery is LEFT JOINed on the compared operand,
    and the predicate is replaced with ``NOT <alias>.<value> IS NULL``.

    The tree is left unchanged when any of these holds:
    - the subquery projects more than one column;
    - it is not inside a condition of ``parent_select``;
    - for ``IN`` / ``ANY``, it has a LIMIT/OFFSET;
    - for ``ANY``, the ``ANY`` is not inside an equality of ``parent_select``.

    Args:
        select (sqlglot.exp.Select): the subquery; modified in place.
        parent_select (sqlglot.exp.Select): the enclosing SELECT that receives the join.
        next_alias_name (Callable[[], str]): returns a fresh table alias on each call.
    """
    # Only single-column subqueries can be mapped onto one join value.
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)

    if not predicate or parent_select is not predicate.parent_select:
        return

    # Draw the alias only once the subquery is known to belong to parent_select, so
    # skipped subqueries do not leave gaps in the generated alias sequence.
    alias = next_alias_name()

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        # In an aggregating context a bare column reference is not allowed, so wrap it
        # in MAX.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For `x = ANY (subquery)` the condition to rewrite is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    # NOTE(review): assumes the projection is aliased (e.g. by qualify_columns), since
    # value.alias is used as the join column name — confirm.
    value = select.selects[0]

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    # After the LEFT JOIN, a match is signalled by a non-NULL join key.
    _replace(predicate, f"NOT {on.right} IS NULL")

    # Each value may appear only once in the joined subquery; otherwise the LEFT JOIN
    # would duplicate rows of parent_select.
    group = select.args.get("group")

    if group:
        if {value.this} != set(group.expressions):
            # Appending to an existing GROUP BY would either keep duplicate values
            # (other keys are still grouped) or group by an aggregate (invalid SQL),
            # so deduplicate the projected value in a wrapping SELECT instead.
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q"))
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    elif not value.find(exp.AggFunc):
        select = select.group_by(value.this, copy=False)
    # Otherwise the value is an aggregate without GROUP BY: it already yields a single
    # row, and grouping by an aggregate expression would be invalid SQL.

    parent_select.join(
        select,
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def
decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite the correlated subquery ``select`` as a LEFT JOIN of ``parent_select``.

    Every external column must sit in a binary predicate of the subquery's own WHERE
    clause. The operand on the other side of that predicate (the "key") is what the
    subquery projects for the join.
    - Equality keys are projected and grouped on, and the equality predicates become
      the join condition.
    - Keys of other comparisons are aggregated: ``MAX`` for subquery projections,
      ``ARRAY_AGG`` otherwise. Outside subquery projections they are re-checked in the
      parent with ``ARRAY_ANY``.

    The parent predicate that held the subquery is rewritten to reference the joined
    derived table. That predicate is ``EXISTS``, ``ALL``, ``ANY``, ``IN``, any other
    comparison, or a projection.

    The tree is left unchanged when any of these holds:
    - the WHERE clause is missing or contains ``OR``;
    - the subquery has ``LIMIT``/``OFFSET``;
    - an external column is not inside a binary predicate of that WHERE clause;
    - none of those predicates is an equality.

    Args:
        select (sqlglot.exp.Select): the correlated subquery; modified in place.
        parent_select (sqlglot.exp.Select): the enclosing SELECT that receives the join.
        external_columns: columns of ``select`` that reference an outer scope.
        next_alias_name (Callable[[], str]): returns a fresh alias on each call.
    """
    where = select.args.get("where")

    # Only OR-free filters can be split into join keys, and a LIMIT/OFFSET would no
    # longer apply per outer row once the subquery is joined.
    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE: drawn before the key checks below, so an alias is consumed even when one
    # of those checks bails out.
    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand that does not contain the external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node itself is one of parent_select's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the GROUP BY
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # non-equality keys are aggregated so they survive the GROUP BY
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need to take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlated predicate leaves the subquery's WHERE; equalities are reused
        # (with the key swapped for the joined column) as the join condition below
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            # NOTE(review): a non-equality predicate is not added to the join condition
            # or the parent here, so it appears to be dropped for subquery
            # projections — confirm.
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            # NOTE(review): appears unreachable — every EQ key was added to group_by above.
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )