sqlglot.optimizer.merge_subqueries
from collections import defaultdict

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope


def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Rewrite sqlglot AST to merge derived tables into the outer query.

    This also merges CTEs if they are selected from only once.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip any merge whose outer query
            selects from more than one source
    Returns:
        sqlglot.Expression: optimized expression
    """
    # CTEs are merged first, derived tables second.
    return merge_derived_tables(
        merge_ctes(expression, leave_tables_isolated),
        leave_tables_isolated,
    )


# If a derived table has these Select args, it can't be merged
UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
    "expressions",
    "from",
    "joins",
    "where",
    "order",
    "hint",
}


def merge_ctes(expression, leave_tables_isolated=False):
    """
    Merge every CTE that is selected from exactly once into the query selecting from it.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the `expression` that was passed in; merging edits it in place
    """
    # All places where we select from CTEs.
    # We key on the identity of the CTE scope so we can detect CTEs that are
    # selected from multiple times.
    selections_by_cte = defaultdict(list)
    for parent_scope in traverse_scope(expression):
        for node, source in parent_scope.selected_sources.values():
            if not isinstance(source, Scope) or not source.is_cte:
                continue
            selections_by_cte[id(source)].append((parent_scope, source, node))

    for uses in selections_by_cte.values():
        # A CTE referenced more than once is left alone.
        if len(uses) != 1:
            continue

        outer_scope, inner_scope, table = uses[0]
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            continue

        alias = table.alias_or_name
        _rename_inner_sources(outer_scope, inner_scope, alias)
        _merge_from(outer_scope, inner_scope, table, alias)
        _merge_expressions(outer_scope, inner_scope, alias)
        _merge_joins(outer_scope, inner_scope, from_or_join)
        _merge_where(outer_scope, inner_scope, from_or_join)
        _merge_order(outer_scope, inner_scope)
        _merge_hints(outer_scope, inner_scope)
        _pop_cte(inner_scope)
        outer_scope.clear_cache()

    return expression


def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Merge derived tables (subqueries used as sources) into the query selecting from them.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the `expression` that was passed in; merging edits it in place
    """
    for outer_scope in traverse_scope(expression):
        for derived_table in outer_scope.derived_tables:
            from_or_join = derived_table.find_ancestor(exp.From, exp.Join)
            alias = derived_table.alias_or_name
            inner_scope = outer_scope.sources[alias]

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, derived_table, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_order(outer_scope, inner_scope)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression


def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
    """
    Return True if `inner_select` can be merged into outer query.

    Args:
        outer_scope (Scope)
        inner_scope (Scope)
        leave_tables_isolated (bool)
        from_or_join (exp.From|exp.Join)
    Returns:
        bool: True if can be merged
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        """
        Return True if the outer query references a window projection of the inner
        query from within a WHERE, GROUP, ORDER, JOIN, HAVING or aggregate function.
        """
        window_expressions = inner_select.find_all(exp.Window)
        window_alias_names = {window.parent.alias_or_name for window in window_expressions}
        inner_select_name = inner_select.parent.alias_or_name
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_alias_names
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.this.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        # Names of the inner query's projections that the ON clause refers to
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.expressions[0].alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
        return any(
            col.table != inner_from_table
            for selection in selections
            for col in inner_projections[selection].find_all(exp.Column)
        )

    return (
        # Both sides must be plain SELECTs
        isinstance(outer_scope.expression, exp.Select)
        and isinstance(inner_select, exp.Select)
        # The inner query may only use args that the merge helpers know how to move
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from")
        # No aggregate functions or nested selects in the inner projections
        and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # An inner WHERE is not merged when the inner query is the target of a sided join
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
        )
        # ... nor when the inner query is in FROM and the outer has a FULL or RIGHT join
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(
                j.side in {"FULL", "RIGHT"}
                for j in outer_scope.expression.args.get("joins") or []
            )
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
    )


def _rename_inner_sources(outer_scope, inner_scope, alias):
    """
    Renames any sources in the inner query that conflict with names in the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    inner_taken = set(inner_scope.selected_sources)
    outer_taken = set(outer_scope.selected_sources)

    # The alias of the inner query itself goes away in the merge, so it is not a conflict
    conflicts = outer_taken.intersection(inner_taken)
    conflicts -= {alias}

    # Reserve the names of both scopes: a generated name must not clash with
    # another inner source that is merged alongside the renamed one.
    taken = outer_taken.union(inner_taken)

    for conflict in conflicts:
        new_name = find_new_name(taken, conflict)
        taken.add(new_name)

        source, _ = inner_scope.selected_sources[conflict]
        new_alias = exp.to_identifier(new_name)

        if isinstance(source, exp.Subquery):
            source.set("alias", exp.TableAlias(this=new_alias))
        elif isinstance(source, exp.Table) and source.alias:
            source.set("alias", new_alias)
        elif isinstance(source, exp.Table):
            source.replace(exp.alias_(source.copy(), new_alias))

        # Point the inner query's columns at the new name
        for column in inner_scope.source_columns(conflict):
            column.set("table", exp.to_identifier(new_name))

        inner_scope.rename_source(conflict, new_name)


def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
    """
    Merge FROM clause of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table)
        alias (str)
    """
    # Only the first FROM expression is placed here; `_merge_joins` handles the rest
    new_subquery = inner_scope.expression.args.get("from").expressions[0]
    node_to_replace.replace(new_subquery)

    # Join hints that named the replaced node are updated to name the new source
    for join_hint in outer_scope.join_hints:
        tables = join_hint.find_all(exp.Table)
        for table in tables:
            if table.alias_or_name == node_to_replace.alias_or_name:
                table.set("this", exp.to_identifier(new_subquery.alias_or_name))

    outer_scope.remove_source(alias)
    outer_scope.add_source(
        new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
    )


def _merge_joins(outer_scope, inner_scope, from_or_join):
    """
    Merge JOIN clauses of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """

    new_joins = []

    # Additional comma-separated FROM expressions become CROSS joins
    comma_joins = inner_scope.expression.args.get("from").expressions[1:]
    for subquery in comma_joins:
        new_joins.append(exp.Join(this=subquery, kind="CROSS"))
        outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])

    joins = inner_scope.expression.args.get("joins") or []
    for join in joins:
        new_joins.append(join)
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    if new_joins:
        outer_joins = outer_scope.expression.args.get("joins") or []

        # Maintain the join order
        if isinstance(from_or_join, exp.From):
            position = 0
        else:
            position = outer_joins.index(from_or_join) + 1
        outer_joins[position:position] = new_joins

        outer_scope.expression.set("joins", outer_joins)


def _merge_expressions(outer_scope, inner_scope, alias):
    """
    Merge projections of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for expression in inner_scope.expression.expressions:
        projection_name = expression.alias_or_name
        if not projection_name:
            continue
        columns_to_replace = outer_columns.get(projection_name, [])
        for column in columns_to_replace:
            # Each reference gets its own copy of the projection
            column.replace(expression.unalias().copy())


def _merge_where(outer_scope, inner_scope, from_or_join):
    """
    Merge WHERE clause of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    where = inner_scope.expression.args.get("where")
    if not where or not where.this:
        return

    expression = outer_scope.expression

    if isinstance(from_or_join, exp.Join):
        # Merge predicates from an outer join to the ON clause
        # if it only has columns that are already joined
        from_ = expression.args.get("from")
        # The fallback must be a set: `sources.add` is called below
        sources = {table.alias_or_name for table in from_.expressions} if from_ else set()

        # Collect every source joined up to and including the merged join
        for join in expression.args["joins"]:
            source = join.alias_or_name
            sources.add(source)
            if source == from_or_join.alias_or_name:
                break

        if set(exp.column_table_names(where.this)) <= sources:
            from_or_join.on(where.this, copy=False)
            # NOTE(review): re-setting the arg presumably refreshes parent links
            # after the copy=False append; confirm against Expression.set.
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    expression.where(where.this, copy=False)
    # NOTE(review): same re-set pattern as for "on" above; confirm intent.
    expression.set("where", expression.args.get("where"))


def _merge_order(outer_scope, inner_scope):
    """
    Merge ORDER clause of inner query into outer query.

    Nothing is merged if the outer query has a GROUP, DISTINCT, HAVING or ORDER
    clause, selects from more than one source, or has an aggregate function in
    its projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    if (
        any(
            outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
        )
        or len(outer_scope.selected_sources) != 1
        or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
    ):
        return

    outer_scope.expression.set("order", inner_scope.expression.args.get("order"))


def _merge_hints(outer_scope, inner_scope):
    """
    Merge hints of inner query into outer query.

    If the outer query already has a hint, the inner hint's expressions are
    appended to it; otherwise the inner hint is set on the outer query as is.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    inner_scope_hint = inner_scope.expression.args.get("hint")
    if not inner_scope_hint:
        return
    outer_scope_hint = outer_scope.expression.args.get("hint")
    if outer_scope_hint:
        for hint_expression in inner_scope_hint.expressions:
            outer_scope_hint.append("expressions", hint_expression)
    else:
        outer_scope.expression.set("hint", inner_scope_hint)


def _pop_cte(inner_scope):
    """
    Remove CTE from the AST.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent
    # Dropping the only CTE removes the whole WITH clause
    if len(with_.expressions) == 1:
        with_.pop()
    else:
        cte.pop()
def
merge_subqueries(expression, leave_tables_isolated=False):
def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Rewrite sqlglot AST to merge derived tables into the outer query.

    CTEs that are selected from only once are merged as well.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): passed through to both merge passes
    Returns:
        sqlglot.Expression: optimized expression
    """
    # CTEs are merged first, derived tables second.
    return merge_derived_tables(
        merge_ctes(expression, leave_tables_isolated),
        leave_tables_isolated,
    )
Rewrite sqlglot AST to merge derived tables into the outer query.
This also merges CTEs if they are selected from only once.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Arguments:
- expression (sqlglot.Expression): expression to optimize
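- leave_tables_isolated (bool): if True, do not merge inner queries into outer queries when that would result in multiple table selects in a single query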
Returns:
sqlglot.Expression: optimized expression
def
merge_ctes(expression, leave_tables_isolated=False):
def merge_ctes(expression, leave_tables_isolated=False):
    """
    Merge every CTE that is selected from exactly once into the query selecting from it.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the `expression` that was passed in; merging edits it in place
    """
    # All places where we select from CTEs.
    # We key on the identity of the CTE scope so we can detect CTEs that are
    # selected from multiple times.
    selections_by_cte = defaultdict(list)
    for parent_scope in traverse_scope(expression):
        for node, source in parent_scope.selected_sources.values():
            if not isinstance(source, Scope) or not source.is_cte:
                continue
            selections_by_cte[id(source)].append((parent_scope, source, node))

    for uses in selections_by_cte.values():
        # A CTE referenced more than once is left alone.
        if len(uses) != 1:
            continue

        outer_scope, inner_scope, table = uses[0]
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            continue

        alias = table.alias_or_name
        _rename_inner_sources(outer_scope, inner_scope, alias)
        _merge_from(outer_scope, inner_scope, table, alias)
        _merge_expressions(outer_scope, inner_scope, alias)
        _merge_joins(outer_scope, inner_scope, from_or_join)
        _merge_where(outer_scope, inner_scope, from_or_join)
        _merge_order(outer_scope, inner_scope)
        _merge_hints(outer_scope, inner_scope)
        _pop_cte(inner_scope)
        outer_scope.clear_cache()

    return expression
def
merge_derived_tables(expression, leave_tables_isolated=False):
def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Merge derived tables (subqueries used as sources) into the query selecting from them.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the `expression` that was passed in; merging edits it in place
    """
    for outer_scope in traverse_scope(expression):
        for derived_table in outer_scope.derived_tables:
            from_or_join = derived_table.find_ancestor(exp.From, exp.Join)
            alias = derived_table.alias_or_name
            inner_scope = outer_scope.sources[alias]

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, derived_table, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_order(outer_scope, inner_scope)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression