sqlglot.optimizer.merge_subqueries
from collections import defaultdict

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify


def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Rewrite sqlglot AST to merge derived tables into the outer query.

    This also merges CTEs if they are selected from only once.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip any merge that would leave the
            outer query selecting from more than one source
    Returns:
        sqlglot.Expression: optimized expression
    """
    expression = merge_ctes(expression, leave_tables_isolated)
    expression = merge_derived_tables(expression, leave_tables_isolated)
    return expression


# If a derived table has these Select args, it can't be merged
UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
    "expressions",
    "from",
    "joins",
    "where",
    "order",
    "hint",
}


def merge_ctes(expression, leave_tables_isolated=False):
    """
    Merge CTEs that are selected from exactly once into the query that selects them.

    Args:
        expression (sqlglot.Expression): expression to optimize; its AST is rewritten
            in place and the merged CTE definitions are popped from their WITH clause
        leave_tables_isolated (bool): passed through to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    scopes = traverse_scope(expression)

    # All places where we select from CTEs.
    # We key on the CTE scope so we can detect CTEs that are selected from multiple times.
    cte_selections = defaultdict(list)
    for outer_scope in scopes:
        for table, inner_scope in outer_scope.selected_sources.values():
            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
                cte_selections[id(inner_scope)].append(
                    (
                        outer_scope,
                        inner_scope,
                        table,
                    )
                )

    # Only a CTE referenced from a single place can be inlined without duplicating it.
    singular_cte_selections = [v[0] for v in cte_selections.values() if len(v) == 1]
    for outer_scope, inner_scope, table in singular_cte_selections:
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            alias = table.alias_or_name
            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, table, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_order(outer_scope, inner_scope)
            _merge_hints(outer_scope, inner_scope)
            _pop_cte(inner_scope)
            outer_scope.clear_cache()
    return expression


def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Merge derived tables (subqueries in FROM/JOIN) into their enclosing query.

    Args:
        expression (sqlglot.Expression): expression to optimize; its AST is rewritten
            in place
        leave_tables_isolated (bool): passed through to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]
            if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                _rename_inner_sources(outer_scope, inner_scope, alias)
                _merge_from(outer_scope, inner_scope, subquery, alias)
                _merge_expressions(outer_scope, inner_scope, alias)
                _merge_joins(outer_scope, inner_scope, from_or_join)
                _merge_where(outer_scope, inner_scope, from_or_join)
                _merge_order(outer_scope, inner_scope)
                _merge_hints(outer_scope, inner_scope)
                outer_scope.clear_cache()
    return expression


def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
    """
    Return True if `inner_select` can be merged into outer query.

    Args:
        outer_scope (Scope)
        inner_scope (Scope)
        leave_tables_isolated (bool)
        from_or_join (exp.From|exp.Join)
    Returns:
        bool: True if can be merged
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        """
        True if an outer column that refers to a window-function projection of the
        inner select is used under WHERE/GROUP/ORDER/JOIN/HAVING or inside an aggregate.
        Inlining the window expression there would change its meaning.
        """
        window_expressions = inner_select.find_all(exp.Window)
        window_alias_names = {window.parent.alias_or_name for window in window_expressions}
        inner_select_name = inner_select.parent.alias_or_name
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_alias_names
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.this.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.expressions[0].alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
        return any(
            col.table != inner_from_table
            for selection in selections
            for col in inner_projections[selection].find_all(exp.Column)
        )

    # `joins` may be absent or explicitly None; normalize to a list before iterating.
    outer_joins = outer_scope.expression.args.get("joins") or []

    return (
        isinstance(outer_scope.expression, exp.Select)
        and isinstance(inner_select, exp.Select)
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from")
        and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # A WHERE on the inner side of an outer join can't be hoisted without
        # changing which rows are null-extended.
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
        )
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(j.side in {"FULL", "RIGHT"} for j in outer_joins)
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
    )


def _rename_inner_sources(outer_scope, inner_scope, alias):
    """
    Renames any sources in the inner query that conflict with names in the outer query.

    New names are chosen so they collide with neither the outer sources, the inner
    sources, nor any name generated earlier in the same call.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    outer_taken = set(outer_scope.selected_sources)
    inner_taken = set(inner_scope.selected_sources)
    conflicts = outer_taken.intersection(inner_taken)
    conflicts -= {alias}

    # A fresh name must avoid inner names as well as outer ones, otherwise renaming
    # e.g. inner `x` to `x_2` could clash with an existing inner `x_2`.
    taken = outer_taken | inner_taken

    # Sorted so the generated names are deterministic across runs.
    for conflict in sorted(conflicts):
        new_name = find_new_name(taken, conflict)
        taken.add(new_name)

        source, _ = inner_scope.selected_sources[conflict]
        new_alias = exp.to_identifier(new_name)

        if isinstance(source, exp.Subquery):
            source.set("alias", exp.TableAlias(this=new_alias))
        elif isinstance(source, exp.Table) and source.alias:
            source.set("alias", new_alias)
        elif isinstance(source, exp.Table):
            source.replace(exp.alias_(source.copy(), new_alias))

        for column in inner_scope.source_columns(conflict):
            column.set("table", exp.to_identifier(new_name))

        inner_scope.rename_source(conflict, new_name)


def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
    """
    Merge FROM clause of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table)
        alias (str)
    """
    new_subquery = inner_scope.expression.args.get("from").expressions[0]
    node_to_replace.replace(new_subquery)
    # Join hints that named the replaced source must now name its replacement.
    for join_hint in outer_scope.join_hints:
        tables = join_hint.find_all(exp.Table)
        for table in tables:
            if table.alias_or_name == node_to_replace.alias_or_name:
                table.set("this", exp.to_identifier(new_subquery.alias_or_name))
    outer_scope.remove_source(alias)
    outer_scope.add_source(
        new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
    )


def _merge_joins(outer_scope, inner_scope, from_or_join):
    """
    Merge JOIN clauses of inner query into outer query.

    Extra comma-separated FROM tables of the inner query become CROSS joins.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """

    new_joins = []
    comma_joins = inner_scope.expression.args.get("from").expressions[1:]
    for subquery in comma_joins:
        new_joins.append(exp.Join(this=subquery, kind="CROSS"))
        outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])

    joins = inner_scope.expression.args.get("joins") or []
    for join in joins:
        new_joins.append(join)
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    if new_joins:
        # `joins` may be absent or explicitly None; slicing None would raise.
        outer_joins = outer_scope.expression.args.get("joins") or []

        # Maintain the join order
        if isinstance(from_or_join, exp.From):
            position = 0
        else:
            position = outer_joins.index(from_or_join) + 1
        outer_joins[position:position] = new_joins

        outer_scope.expression.set("joins", outer_joins)


def _merge_expressions(outer_scope, inner_scope, alias):
    """
    Merge projections of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for expression in inner_scope.expression.expressions:
        projection_name = expression.alias_or_name
        if not projection_name:
            continue
        columns_to_replace = outer_columns.get(projection_name, [])
        for column in columns_to_replace:
            column.replace(expression.unalias().copy())


def _merge_where(outer_scope, inner_scope, from_or_join):
    """
    Merge WHERE clause of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    where = inner_scope.expression.args.get("where")
    if not where or not where.this:
        return

    if isinstance(from_or_join, exp.Join):
        # Merge predicates from an outer join to the ON clause
        from_or_join.on(where.this, copy=False)
        from_or_join.set("on", simplify(from_or_join.args.get("on")))
    else:
        outer_scope.expression.where(where.this, copy=False)
        outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))


def _merge_order(outer_scope, inner_scope):
    """
    Merge ORDER clause of inner query into outer query.

    The inner ORDER BY is only carried over when the outer query selects from a
    single source and has no GROUP/DISTINCT/HAVING/ORDER or aggregate projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    if (
        any(
            outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
        )
        or len(outer_scope.selected_sources) != 1
        or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
    ):
        return

    outer_scope.expression.set("order", inner_scope.expression.args.get("order"))


def _merge_hints(outer_scope, inner_scope):
    """
    Merge the inner query's hint into the outer query.

    The inner hint's expressions are appended to the outer hint if one exists;
    otherwise the inner hint node is moved onto the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    inner_scope_hint = inner_scope.expression.args.get("hint")
    if not inner_scope_hint:
        return
    outer_scope_hint = outer_scope.expression.args.get("hint")
    if outer_scope_hint:
        for hint_expression in inner_scope_hint.expressions:
            outer_scope_hint.append("expressions", hint_expression)
    else:
        outer_scope.expression.set("hint", inner_scope_hint)


def _pop_cte(inner_scope):
    """
    Remove CTE from the AST.

    If it is the only CTE in its WITH clause, the whole WITH clause is removed.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent
    if len(with_.expressions) == 1:
        with_.pop()
    else:
        cte.pop()
def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Fold derived tables of a sqlglot AST into the query that selects from them.

    CTEs that are referenced exactly once are folded in as well.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x JOIN y'

    Passing `leave_tables_isolated=True` keeps an inner query separate whenever
    merging it would make a single query select from multiple tables:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): when True, skip merges that would result in
            multiple table selects in a single query
    Returns:
        sqlglot.Expression: optimized expression
    """
    # CTE merging runs first; any derived tables that remain are merged afterwards.
    after_ctes = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(after_ctes, leave_tables_isolated)
Rewrite sqlglot AST to merge derived tables into the outer query.
This also merges CTEs if they are selected from only once.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") >>> merge_subqueries(expression).sql() 'SELECT x.a FROM x JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer queries if doing so would result in multiple table selects in a single query:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") >>> merge_subqueries(expression, leave_tables_isolated=True).sql() 'SELECT a FROM (SELECT x.a FROM x) JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Arguments:
- expression (sqlglot.Expression): expression to optimize
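- leave_tables_isolated (bool): if True, do not merge an inner query when doing so would result in multiple table selects in a single query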
Returns:
sqlglot.Expression: optimized expression
def merge_ctes(expression, leave_tables_isolated=False):
    """
    Inline every CTE that is selected from exactly once into the query selecting it.

    Args:
        expression (sqlglot.Expression): expression to optimize; rewritten in place
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    # Group each (outer scope, CTE scope, referencing node) triple by the identity
    # of the CTE scope, so CTEs referenced from several places can be recognized.
    references_by_cte = defaultdict(list)
    for outer in traverse_scope(expression):
        for source_node, source_scope in outer.selected_sources.values():
            if not (isinstance(source_scope, Scope) and source_scope.is_cte):
                continue
            references_by_cte[id(source_scope)].append((outer, source_scope, source_node))

    # Only a CTE with a single reference can be inlined without duplicating it.
    candidates = [refs[0] for refs in references_by_cte.values() if len(refs) == 1]

    for outer, cte_scope, source_node in candidates:
        from_or_join = source_node.find_ancestor(exp.From, exp.Join)
        if not _mergeable(outer, cte_scope, leave_tables_isolated, from_or_join):
            continue
        alias = source_node.alias_or_name
        _rename_inner_sources(outer, cte_scope, alias)
        _merge_from(outer, cte_scope, source_node, alias)
        _merge_expressions(outer, cte_scope, alias)
        _merge_joins(outer, cte_scope, from_or_join)
        _merge_where(outer, cte_scope, from_or_join)
        _merge_order(outer, cte_scope)
        _merge_hints(outer, cte_scope)
        _pop_cte(cte_scope)
        outer.clear_cache()
    return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Inline derived tables (subqueries used in FROM/JOIN) into their enclosing query.

    Args:
        expression (sqlglot.Expression): expression to optimize; rewritten in place
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    for outer in traverse_scope(expression):
        for derived in outer.derived_tables:
            alias = derived.alias_or_name
            from_or_join = derived.find_ancestor(exp.From, exp.Join)
            inner = outer.sources[alias]
            if not _mergeable(outer, inner, leave_tables_isolated, from_or_join):
                continue
            _rename_inner_sources(outer, inner, alias)
            _merge_from(outer, inner, derived, alias)
            _merge_expressions(outer, inner, alias)
            _merge_joins(outer, inner, from_or_join)
            _merge_where(outer, inner, from_or_join)
            _merge_order(outer, inner)
            _merge_hints(outer, inner)
            outer.clear_cache()
    return expression