sqlglot.optimizer.merge_subqueries
from collections import defaultdict

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope


def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Flatten derived tables (and CTEs that are selected from exactly once) into
    the query that selects from them.

    CTEs are merged first, then derived tables.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): when True, an inner query is left in place
            if the outer query selects from more than one source
    Returns:
        sqlglot.Expression: optimized expression
    """
    without_ctes = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(without_ctes, leave_tables_isolated)


# A derived table that sets any Select arg outside this allow-list can't be merged
UNMERGABLE_ARGS = set(exp.Select.arg_types).difference(
    ("expressions", "from", "joins", "where", "order", "hint")
)


# Projections in the outer query that are instances of these types can be replaced
# without getting wrapped in parentheses, because the precedence won't be altered.
SAFE_TO_REPLACE_UNWRAPPED = (
    exp.Column,
    exp.EQ,
    exp.Func,
    exp.NEQ,
    exp.Paren,
)


def merge_ctes(expression, leave_tables_isolated=False):
    """
    Merge CTEs that are selected from exactly once into the query selecting from them.

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    scopes = traverse_scope(expression)

    # All places where we select from CTEs.
    # We key on the CTE scope so we can detect CTEs that are selected from multiple times.
    cte_selections = defaultdict(list)
    for outer_scope in scopes:
        for table, inner_scope in outer_scope.selected_sources.values():
            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
                cte_selections[id(inner_scope)].append((outer_scope, inner_scope, table))

    singular_cte_selections = [v[0] for v in cte_selections.values() if len(v) == 1]
    for outer_scope, inner_scope, table in singular_cte_selections:
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            alias = table.alias_or_name
            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, table, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_order(outer_scope, inner_scope)
            _merge_hints(outer_scope, inner_scope)
            _pop_cte(inner_scope)
            outer_scope.clear_cache()
    return expression


def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Merge derived tables into the query that selects from them.

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]
            if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                _rename_inner_sources(outer_scope, inner_scope, alias)
                _merge_from(outer_scope, inner_scope, subquery, alias)
                _merge_expressions(outer_scope, inner_scope, alias)
                _merge_joins(outer_scope, inner_scope, from_or_join)
                _merge_where(outer_scope, inner_scope, from_or_join)
                _merge_order(outer_scope, inner_scope)
                _merge_hints(outer_scope, inner_scope)
                outer_scope.clear_cache()

    return expression


def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
    """
    Return a truthy value if the query of `inner_scope` can be merged into the outer query.

    Args:
        outer_scope (Scope)
        inner_scope (Scope)
        leave_tables_isolated (bool)
        from_or_join (exp.From|exp.Join)
    Returns:
        bool: True if can be merged (a falsy, possibly non-bool, value otherwise)
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        """
        True if the outer query references a window projection of the inner query
        inside a WHERE, GROUP, ORDER, JOIN, HAVING or aggregate function.
        """
        window_expressions = inner_select.find_all(exp.Window)
        window_alias_names = {window.parent.alias_or_name for window in window_expressions}
        inner_select_name = from_or_join.alias_or_name
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_alias_names
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}
        return any(
            col.table != inner_from_table
            for selection in selections
            for col in inner_projections[selection].find_all(exp.Column)
        )

    def _is_recursive():
        # Recursive CTEs look like this:
        #     WITH RECURSIVE cte AS (
        #       SELECT * FROM x  <-- inner scope
        #       UNION ALL
        #       SELECT * FROM cte  <-- outer scope
        #     )
        cte = inner_scope.expression.parent
        node = outer_scope.expression.parent

        # The outer query is recursive if it lives underneath the CTE it selects from
        while node:
            if node is cte:
                return True
            node = node.parent
        return False

    return (
        isinstance(outer_scope.expression, exp.Select)
        and not outer_scope.expression.is_star
        and isinstance(inner_select, exp.Select)
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from")
        and not outer_scope.pivots
        and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # An inner WHERE can't be hoisted when the inner query sits on an outer-join side
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in ("FULL", "LEFT", "RIGHT")
        )
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(
                j.side in ("FULL", "RIGHT")
                # `or []` also guards an explicit None stored under "joins"
                for j in (outer_scope.expression.args.get("joins") or [])
            )
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
        and not _is_recursive()
        and not (inner_select.args.get("order") and outer_scope.is_union)
    )


def _rename_inner_sources(outer_scope, inner_scope, alias):
    """
    Renames any sources in the inner query that conflict with names in the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    inner_taken = set(inner_scope.selected_sources)
    outer_taken = set(outer_scope.selected_sources)
    conflicts = outer_taken.intersection(inner_taken)
    conflicts -= {alias}

    # A new name must be unique among the sources of BOTH queries, otherwise the
    # renamed source could collide with another source of the inner query.
    taken = outer_taken.union(inner_taken)

    for conflict in conflicts:
        new_name = find_new_name(taken, conflict)
        # Reserve the name so a later conflict cannot be assigned the same one
        taken.add(new_name)

        source, _ = inner_scope.selected_sources[conflict]
        new_alias = exp.to_identifier(new_name)

        if isinstance(source, exp.Subquery):
            source.set("alias", exp.TableAlias(this=new_alias))
        elif isinstance(source, exp.Table) and source.alias:
            source.set("alias", new_alias)
        elif isinstance(source, exp.Table):
            source.replace(exp.alias_(source, new_alias))

        for column in inner_scope.source_columns(conflict):
            column.set("table", exp.to_identifier(new_name))

        inner_scope.rename_source(conflict, new_name)


def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
    """
    Merge FROM clause of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table)
        alias (str)
    """
    new_subquery = inner_scope.expression.args["from"].this
    # Carry over any joins attached to the node that is being replaced
    new_subquery.set("joins", node_to_replace.args.get("joins"))
    node_to_replace.replace(new_subquery)
    # Point join hints that named the replaced node at the new source
    for join_hint in outer_scope.join_hints:
        tables = join_hint.find_all(exp.Table)
        for table in tables:
            if table.alias_or_name == node_to_replace.alias_or_name:
                table.set("this", exp.to_identifier(new_subquery.alias_or_name))
    outer_scope.remove_source(alias)
    outer_scope.add_source(
        new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
    )


def _merge_joins(outer_scope, inner_scope, from_or_join):
    """
    Merge JOIN clauses of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """

    new_joins = []

    joins = inner_scope.expression.args.get("joins") or []
    for join in joins:
        new_joins.append(join)
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    if new_joins:
        # `or []` (rather than a .get default) also covers an explicit None value
        outer_joins = outer_scope.expression.args.get("joins") or []

        # Maintain the join order
        if isinstance(from_or_join, exp.From):
            position = 0
        else:
            position = outer_joins.index(from_or_join) + 1
        outer_joins[position:position] = new_joins

        outer_scope.expression.set("joins", outer_joins)


def _merge_expressions(outer_scope, inner_scope, alias):
    """
    Merge projections of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for projection in inner_scope.expression.expressions:
        projection_name = projection.alias_or_name
        if not projection_name:
            continue
        columns_to_replace = outer_columns.get(projection_name, [])

        expression = projection.unalias()
        must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED)

        for column in columns_to_replace:
            # Work on a fresh copy per column, so wrapping one replacement never leaks
            # into the next column (no stacked parentheses) or into the inner tree.
            replacement = expression.copy()

            # Ensures we don't alter the intended operator precedence if there's additional
            # context surrounding the outer expression (i.e. it's not a simple projection).
            if isinstance(column.parent, (exp.Unary, exp.Binary)) and must_wrap_expression:
                replacement = exp.paren(replacement, copy=False)

            column.replace(replacement)


def _merge_where(outer_scope, inner_scope, from_or_join):
    """
    Merge WHERE clause of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    where = inner_scope.expression.args.get("where")
    if not where or not where.this:
        return

    expression = outer_scope.expression

    if isinstance(from_or_join, exp.Join):
        # Merge predicates from an outer join to the ON clause
        # if it only has columns that are already joined
        from_ = expression.args.get("from")
        # Must be a set in both branches: `{}` is a dict, which has no .add()
        # and cannot be compared with a set via `<=`.
        sources = {from_.alias_or_name} if from_ else set()

        for join in expression.args["joins"]:
            source = join.alias_or_name
            sources.add(source)
            if source == from_or_join.alias_or_name:
                break

        if exp.column_table_names(where.this) <= sources:
            from_or_join.on(where.this, copy=False)
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    expression.where(where.this, copy=False)


def _merge_order(outer_scope, inner_scope):
    """
    Merge ORDER clause of inner query into outer query.

    The order is only carried over when the outer query has no GROUP, DISTINCT,
    HAVING or ORDER of its own, selects from exactly one source and projects no
    aggregate function.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    if (
        any(
            outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
        )
        or len(outer_scope.selected_sources) != 1
        or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
    ):
        return

    outer_scope.expression.set("order", inner_scope.expression.args.get("order"))


def _merge_hints(outer_scope, inner_scope):
    """
    Merge hints of inner query into outer query.

    Inner hint expressions are appended to an existing outer hint; otherwise the
    inner hint node becomes the outer query's hint.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    inner_scope_hint = inner_scope.expression.args.get("hint")
    if not inner_scope_hint:
        return
    outer_scope_hint = outer_scope.expression.args.get("hint")
    if outer_scope_hint:
        for hint_expression in inner_scope_hint.expressions:
            outer_scope_hint.append("expressions", hint_expression)
    else:
        outer_scope.expression.set("hint", inner_scope_hint)


def _pop_cte(inner_scope):
    """
    Remove CTE from the AST.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent
    # Drop the whole WITH clause when this was its only CTE
    if len(with_.expressions) == 1:
        with_.pop()
    else:
        cte.pop()
def
merge_subqueries(expression, leave_tables_isolated=False):
def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Flatten derived tables (and CTEs that are selected from exactly once) into
    the query that selects from them.

    CTEs are merged first, then derived tables.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): passed through to both merge passes
    Returns:
        sqlglot.Expression: optimized expression
    """
    without_ctes = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(without_ctes, leave_tables_isolated)
Rewrite sqlglot AST to merge derived tables into the outer query.
This also merges CTEs if they are selected from only once.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") >>> merge_subqueries(expression).sql() 'SELECT x.a FROM x CROSS JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") >>> merge_subqueries(expression, leave_tables_isolated=True).sql() 'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Arguments:
- expression (sqlglot.Expression): expression to optimize
- leave_tables_isolated (bool): if True, do not merge an inner query when the outer query selects from more than one source
Returns:
sqlglot.Expression: optimized expression
UNMERGABLE_ARGS =
{'offset', 'prewhere', 'match', 'locks', 'qualify', 'windows', 'pivots', 'cluster', 'settings', 'having', 'group', 'options', 'distinct', 'with', 'distribute', 'sample', 'format', 'connect', 'laterals', 'limit', 'sort', 'into', 'kind'}
SAFE_TO_REPLACE_UNWRAPPED =
(<class 'sqlglot.expressions.Column'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.Func'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.Paren'>)
def
merge_ctes(expression, leave_tables_isolated=False):
def merge_ctes(expression, leave_tables_isolated=False):
    """
    Merge CTEs that are selected from exactly once into the query selecting from them.

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    # Group every (selecting scope, CTE scope, table reference) triple by the identity
    # of the CTE scope, so CTEs that are referenced more than once can be told apart.
    references_by_cte = defaultdict(list)
    for selecting_scope in traverse_scope(expression):
        for table_ref, source in selecting_scope.selected_sources.values():
            if not isinstance(source, Scope) or not source.is_cte:
                continue
            references_by_cte[id(source)].append((selecting_scope, source, table_ref))

    for references in references_by_cte.values():
        # Only CTEs with a single reference are candidates for merging
        if len(references) != 1:
            continue

        outer_scope, inner_scope, table = references[0]
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            continue

        alias = table.alias_or_name
        _rename_inner_sources(outer_scope, inner_scope, alias)
        _merge_from(outer_scope, inner_scope, table, alias)
        _merge_expressions(outer_scope, inner_scope, alias)
        _merge_joins(outer_scope, inner_scope, from_or_join)
        _merge_where(outer_scope, inner_scope, from_or_join)
        _merge_order(outer_scope, inner_scope)
        _merge_hints(outer_scope, inner_scope)
        _pop_cte(inner_scope)
        outer_scope.clear_cache()

    return expression
def
merge_derived_tables(expression, leave_tables_isolated=False):
def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Merge derived tables into the query that selects from them.

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    for outer_scope in traverse_scope(expression):
        for derived_table in outer_scope.derived_tables:
            enclosing_clause = derived_table.find_ancestor(exp.From, exp.Join)
            alias = derived_table.alias_or_name
            inner_scope = outer_scope.sources[alias]

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, enclosing_clause):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, derived_table, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_joins(outer_scope, inner_scope, enclosing_clause)
            _merge_where(outer_scope, inner_scope, enclosing_clause)
            _merge_order(outer_scope, inner_scope)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression