sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Uncorrelated scalar subqueries become CROSS JOINs and uncorrelated IN / = ANY
    subqueries become LEFT JOINs. Correlated subqueries are grouped by their join keys
    so that the resulting LEFT JOIN is not many-to-many. The expression is modified in place.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression (the same object that was passed in)
    """
    alias_names = name_sequence("_u_")

    for scope in traverse_scope(expression):
        inner = scope.expression
        outer = inner.parent_select

        # a query with no enclosing SELECT has nothing to be unnested into
        if not outer:
            continue

        correlated_columns = scope.external_columns
        if correlated_columns:
            decorrelate(inner, outer, correlated_columns, alias_names)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(inner, outer, alias_names)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated subquery ``select`` into a join on ``parent_select``.

    A subquery used as a scalar is CROSS JOINed. It is wrapped in MAX when the parent
    aggregates. A subquery used in ``IN`` or ``= ANY`` is LEFT JOINed on its single
    projection, which is made unique first so the join cannot duplicate parent rows.
    The AST is modified in place, and unsupported shapes are left untouched.

    Args:
        select: the subquery's SELECT expression.
        parent_select: the enclosing SELECT that receives the join.
        next_alias_name: callable returning a fresh alias name on each call.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE: the alias is drawn even if we bail out below; aliases only need to be unique
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # In the parent's HAVING, or outside its WHERE/HAVING/JOIN while the parent groups or
        # aggregates, a bare cross-joined column is not allowed, so it is wrapped in MAX.
        # This presumes the scalar subquery yields at most one row, so MAX keeps its value.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    # for `x = ANY (subquery)` the comparison to rewrite is the enclosing equality
    if isinstance(predicate, exp.Any):
        predicate = predicate.find_ancestor(exp.EQ)

    if not predicate or parent_select is not predicate.parent_select:
        return

    column = _other_operand(predicate)
    value = select.selects[0]

    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    # membership now means "the LEFT JOIN found a match"; inside a JOIN clause the
    # predicate becomes TRUE and the check moves to the parent's WHERE
    if isinstance(clause, exp.Join):
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    # The joined values must be unique, or the LEFT JOIN multiplies parent rows.
    group = select.args.get("group")

    if group:
        # Appending the value to an existing GROUP BY would not make it unique, so the
        # grouped subquery is wrapped and its value is deduplicated on the outside.
        if {value.this} != set(group.expressions):
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    elif not value.find(exp.AggFunc):
        # An ungrouped aggregate already yields a single row; grouping by it would be invalid SQL.
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery ``select`` as a LEFT JOIN on ``parent_select``.

    Each correlated predicate in the subquery's WHERE is replaced by TRUE there. The
    operand it compared against the outer column is projected under an alias, and the
    equality predicates become the ON conditions of the join. The subquery is grouped by
    its equality join keys so the LEFT JOIN does not multiply rows of the parent. The AST
    is modified in place.

    The AST is not changed when any of the following holds:
    - the subquery has no WHERE, or contains OR, LIMIT or OFFSET;
    - it references outer columns outside its WHERE;
    - a correlated predicate is not a binary comparison;
    - no correlated predicate is an equality.

    Args:
        select: the correlated subquery's SELECT expression.
        parent_select: the enclosing SELECT that receives the join.
        external_columns: columns inside ``select`` that reference outer queries.
        next_alias_name: callable returning a fresh alias name on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE: drawn before the checks below, so an alias may be consumed even if we bail out
    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand on the side that does not contain the outer column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node appears directly among the parent's projections
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # non-equality keys are not grouped, so they are aggregated alongside the value
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # rewrite the predicate that used the subquery in terms of the joined columns
    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlated predicate is removed from the subquery; the detached node is
        # rewritten below and the equality ones are reused as the join's ON conditions
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """Replace ``expression`` in its tree with ``exp.condition(condition)``.

    Returns whatever ``Expression.replace`` returns; callers use it as the new node.
    """
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """Return the operand compared against a subquery, or None if there is no such operand.

    - ``IN``: the left-hand ``this`` operand.
    - ``ANY`` / ``ALL``: resolved through the enclosing comparison.
    - Binary comparison: the side that is not the subquery (or ANY/EXISTS/ALL) operand.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def unnest_subqueries(expression):
def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Uncorrelated subqueries are handed to ``unnest``, and subqueries that reference
    outer columns are handed to ``decorrelate``. Both rewrite the tree in place, and the
    same expression object is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    alias_names = name_sequence("_u_")

    for scope in traverse_scope(expression):
        inner = scope.expression
        outer = inner.parent_select

        # a query with no enclosing SELECT has nothing to be unnested into
        if not outer:
            continue

        correlated_columns = scope.external_columns
        if correlated_columns:
            decorrelate(inner, outer, correlated_columns, alias_names)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(inner, outer, alias_names)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into a GROUP BY so that the resulting LEFT JOIN is not many-to-many.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated subquery ``select`` into a join on ``parent_select``.

    A subquery used as a scalar is CROSS JOINed. It is wrapped in MAX when the parent
    aggregates. A subquery used in ``IN`` or ``= ANY`` is LEFT JOINed on its single
    projection, which is made unique first so the join cannot duplicate parent rows.
    The AST is modified in place, and unsupported shapes are left untouched.

    Args:
        select: the subquery's SELECT expression.
        parent_select: the enclosing SELECT that receives the join.
        next_alias_name: callable returning a fresh alias name on each call.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE: the alias is drawn even if we bail out below; aliases only need to be unique
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # In the parent's HAVING, or outside its WHERE/HAVING/JOIN while the parent groups or
        # aggregates, a bare cross-joined column is not allowed, so it is wrapped in MAX.
        # This presumes the scalar subquery yields at most one row, so MAX keeps its value.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    # for `x = ANY (subquery)` the comparison to rewrite is the enclosing equality
    if isinstance(predicate, exp.Any):
        predicate = predicate.find_ancestor(exp.EQ)

    if not predicate or parent_select is not predicate.parent_select:
        return

    column = _other_operand(predicate)
    value = select.selects[0]

    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    # membership now means "the LEFT JOIN found a match"; inside a JOIN clause the
    # predicate becomes TRUE and the check moves to the parent's WHERE
    if isinstance(clause, exp.Join):
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    # The joined values must be unique, or the LEFT JOIN multiplies parent rows.
    group = select.args.get("group")

    if group:
        # Appending the value to an existing GROUP BY would not make it unique, so the
        # grouped subquery is wrapped and its value is deduplicated on the outside.
        if {value.this} != set(group.expressions):
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    elif not value.find(exp.AggFunc):
        # An ungrouped aggregate already yields a single row; grouping by it would be invalid SQL.
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery ``select`` as a LEFT JOIN on ``parent_select``.

    Each correlated predicate in the subquery's WHERE is replaced by TRUE there. The
    operand it compared against the outer column is projected under an alias, and the
    equality predicates become the ON conditions of the join. The subquery is grouped by
    its equality join keys so the LEFT JOIN does not multiply rows of the parent. The AST
    is modified in place.

    The AST is not changed when any of the following holds:
    - the subquery has no WHERE, or contains OR, LIMIT or OFFSET;
    - it references outer columns outside its WHERE;
    - a correlated predicate is not a binary comparison;
    - no correlated predicate is an equality.

    Args:
        select: the correlated subquery's SELECT expression.
        parent_select: the enclosing SELECT that receives the join.
        external_columns: columns inside ``select`` that reference outer queries.
        next_alias_name: callable returning a fresh alias name on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE: drawn before the checks below, so an alias may be consumed even if we bail out
    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand on the side that does not contain the outer column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node appears directly among the parent's projections
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # non-equality keys are not grouped, so they are aggregated alongside the value
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # rewrite the predicate that used the subquery in terms of the joined columns
    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlated predicate is removed from the subquery; the detached node is
        # rewritten below and the equality ones are reused as the join's ON conditions
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )