sqlglot.optimizer.unnest_subqueries
import itertools

from sqlglot import exp
from sqlglot.optimizer.scope import ScopeType, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # One shared counter so every generated alias (_u_0, _u_1, ...) is unique
    # across the whole expression.
    sequence = itertools.count()

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # Scopes without an enclosing select have nothing to be joined onto.
        if not parent:
            continue
        if scope.external_columns:
            # The scope references columns from outside itself: correlated.
            decorrelate(select, parent, scope.external_columns, sequence)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, sequence)

    # The tree was rewritten in place; the same object is handed back.
    return expression


def unnest(select, parent_select, sequence):
    """
    Turn an uncorrelated subquery used in a condition into a join on its parent.

    Subqueries under IN / ANY become a LEFT JOIN against the subquery grouped by
    its single projection; any other condition is treated as scalar and becomes
    a CROSS JOIN. The tree is modified in place and nothing is returned.

    Args:
        select: the subquery's select expression.
        parent_select: the select that encloses the subquery.
        sequence: iterator of integers used to build join aliases.
    """
    # Only single-projection subqueries are handled.
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # Note: the alias number is consumed even if one of the guards below bails out.
    alias = _alias(sequence)

    # The condition has to belong directly to the parent select.
    if not predicate or parent_select is not predicate.parent_select:
        return

    # this subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        having = predicate.find_ancestor(exp.Having)
        column = exp.column(select.selects[0].alias_or_name, alias)
        # Inside the parent's HAVING the joined column is wrapped in MAX().
        if having and having.parent_select is parent_select:
            column = exp.Max(this=column)
        _replace(select.parent, column)

        parent_select.join(
            select,
            join_type="CROSS",
            join_alias=alias,
            copy=False,
        )
        return

    # Subqueries containing LIMIT / OFFSET are left untouched.
    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For ANY, the predicate that gets rewritten is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    # Join condition: outer operand = "<alias>"."<projection alias>"; the original
    # predicate becomes a NOT ... IS NULL test on the joined column.
    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, sequence):
    """
    Rewrite a correlated subquery as a grouped LEFT JOIN on its parent select.

    Returns early, leaving the tree unchanged, when the subquery has no WHERE,
    its WHERE contains an OR, the subquery contains LIMIT / OFFSET, an external
    column is not inside a binary predicate of this WHERE, or none of the
    correlated predicates is an equality. The tree is modified in place and
    nothing is returned.

    Args:
        select: the correlated subquery's select expression.
        parent_select: the select that encloses the subquery.
        external_columns: columns inside ``select`` that refer to outer scopes.
        sequence: iterator of integers used to build table/column aliases.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # Note: consumed before the guards in the loop below, so a number may be
    # used up without a join being produced.
    table_alias = _alias(sequence)
    # (key, column, predicate): key is the operand on the other side of the
    # predicate from the external column.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # At least one equality is required; only equalities end up in the join's ON.
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = _alias(sequence)
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped are projected through the aggregate instead.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # From here on `alias` is the column of the joined subquery holding its value.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # Rewrite the predicate in the parent that consumed the subquery.
    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            # Keep the projection's original alias on the replacement column.
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # The correlated predicate leaves the subquery's WHERE (replaced by TRUE);
        # the detached predicate object is reused below.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            # Point the detached predicate at the joined subquery's key column.
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    # Only the equality predicates form the ON clause of the LEFT JOIN.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _alias(sequence):
    """Return the next generated alias, ``_u_<n>``, taking ``n`` from ``sequence``."""
    return f"_u_{next(sequence)}"


def _replace(expression, condition):
    """Replace ``expression`` in its tree with ``exp.condition(condition)`` and return the result of ``replace``."""
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand that a subquery-style predicate is compared against.

    For IN this is ``expression.this``; for ANY / ALL the lookup is repeated on
    the parent node; for a binary expression it is the side opposite a
    Subquery / Any / Exists / All on the left (otherwise the left side).
    Returns None for any other expression type.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def
unnest_subqueries(expression):
def unnest_subqueries(expression):
    """
    Replace subqueries used in predicates with joins, rewriting the AST in place.

    Every scope that has an enclosing select is inspected: scopes that reference
    outer columns are handed to ``decorrelate``; the remaining scopes of type
    ``ScopeType.SUBQUERY`` are handed to ``unnest``.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: the same expression object, after rewriting
    """
    # A single counter feeds every helper so generated aliases never collide.
    alias_counter = itertools.count()

    for scope in traverse_scope(expression):
        subquery = scope.expression
        outer_select = subquery.parent_select

        # Only scopes nested inside another select can be turned into joins.
        if outer_select:
            correlated_columns = scope.external_columns
            if correlated_columns:
                decorrelate(subquery, outer_select, correlated_columns, alias_counter)
            elif scope.scope_type == ScopeType.SUBQUERY:
                unnest(subquery, outer_select, alias_counter)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Scalar subqueries are converted into cross joins. Correlated or vectorized subqueries are converted into a LEFT JOIN against a grouped (GROUP BY) subquery, so that the join is never many-to-many.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def
unnest(select, parent_select, sequence):
def unnest(select, parent_select, sequence):
    """
    Convert an uncorrelated subquery inside a condition into a join on its parent.

    A subquery under IN / ANY is grouped by its single projection and LEFT JOINed,
    with the original predicate replaced by a NOT ... IS NULL test on the joined
    column. Under any other condition it is treated as a scalar and CROSS JOINed.
    The tree is mutated in place; nothing is returned.

    Args:
        select: the subquery's select expression.
        parent_select: the select that encloses the subquery.
        sequence: iterator of integers used to build the join alias.
    """
    projections = select.selects
    if len(projections) > 1:
        return

    def lives_in_parent(node):
        # A usable predicate exists and belongs directly to the enclosing select.
        return bool(node) and node.parent_select is parent_select

    condition = select.find_ancestor(exp.Condition)
    # The alias number is drawn before the guard, exactly once per call that gets this far.
    join_alias = _alias(sequence)

    if not lives_in_parent(condition):
        return

    if isinstance(condition, (exp.In, exp.Any)):
        # LIMIT / OFFSET subqueries are not rewritten.
        if select.find(exp.Limit, exp.Offset):
            return

        if isinstance(condition, exp.Any):
            # For ANY the node to rewrite is the surrounding equality.
            condition = condition.find_ancestor(exp.EQ)
            if not lives_in_parent(condition):
                return

        outer_operand = _other_operand(condition)
        projection = projections[0]

        join_condition = exp.condition(
            f'{outer_operand} = "{join_alias}"."{projection.alias}"'
        )
        _replace(condition, f"NOT {join_condition.right} IS NULL")

        grouped = select.group_by(projection.this, copy=False)
        parent_select.join(
            grouped,
            on=join_condition,
            join_type="LEFT",
            join_alias=join_alias,
            copy=False,
        )
        return

    # Anything else: the subquery yields a scalar, so a cross join is enough.
    enclosing_having = condition.find_ancestor(exp.Having)
    scalar_ref = exp.column(projections[0].alias_or_name, join_alias)
    if enclosing_having and enclosing_having.parent_select is parent_select:
        # Inside the parent's HAVING the joined column is wrapped in MAX().
        scalar_ref = exp.Max(this=scalar_ref)
    _replace(select.parent, scalar_ref)

    parent_select.join(select, join_type="CROSS", join_alias=join_alias, copy=False)
def
decorrelate(select, parent_select, external_columns, sequence):
def decorrelate(select, parent_select, external_columns, sequence):
    """
    Rewrite a correlated subquery as a grouped LEFT JOIN on its parent select.

    Returns early, leaving the tree unchanged, when the subquery has no WHERE,
    its WHERE contains an OR, the subquery contains LIMIT / OFFSET, an external
    column is not inside a binary predicate of this WHERE, or none of the
    correlated predicates is an equality. The tree is modified in place and
    nothing is returned.

    Args:
        select: the correlated subquery's select expression.
        parent_select: the select that encloses the subquery.
        external_columns: columns inside ``select`` that refer to outer scopes.
        sequence: iterator of integers used to build table/column aliases.
    """
    # NOTE(review): block nesting below was reconstructed from a
    # whitespace-collapsed listing — confirm against the upstream file.
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # Note: consumed before the guards in the loop below, so a number may be
    # used up without a join being produced.
    table_alias = _alias(sequence)
    # (key, column, predicate): key is the operand on the other side of the
    # predicate from the external column.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # At least one equality is required; only equalities end up in the join's ON.
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = _alias(sequence)
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped are projected through the aggregate instead.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # From here on `alias` is the column of the joined subquery holding its value.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # Rewrite the predicate in the parent that consumed the subquery.
    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            # Keep the projection's original alias on the replacement column.
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # The correlated predicate leaves the subquery's WHERE (replaced by TRUE);
        # the detached predicate object is reused below.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            # Point the detached predicate at the joined subquery's key column.
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    # Only the equality predicates form the ON clause of the LEFT JOIN.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )