diff options
Diffstat (limited to 'sqlglot/optimizer/merge_subqueries.py')
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 287 |
1 files changed, 287 insertions, 0 deletions
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py new file mode 100644 index 0000000..9d966b7 --- /dev/null +++ b/sqlglot/optimizer/merge_subqueries.py @@ -0,0 +1,287 @@ +from collections import defaultdict + +from sqlglot import expressions as exp +from sqlglot.helper import find_new_name +from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def merge_subqueries(expression, leave_tables_isolated=False): + """ + Rewrite sqlglot AST to merge derived tables into the outer query. + + This also merges CTEs if they are selected from only once. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> merge_subqueries(expression).sql() + 'SELECT x.a FROM x JOIN y' + + If `leave_tables_isolated` is True, this will not merge inner queries into outer + queries if it would result in multiple table selects in a single query: + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> merge_subqueries(expression, leave_tables_isolated=True).sql() + 'SELECT a FROM (SELECT x.a FROM x) JOIN y' + + Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html + + Args: + expression (sqlglot.Expression): expression to optimize + leave_tables_isolated (bool): + Returns: + sqlglot.Expression: optimized expression + """ + merge_ctes(expression, leave_tables_isolated) + merge_derived_tables(expression, leave_tables_isolated) + return expression + + +# If a derived table has these Select args, it can't be merged +UNMERGABLE_ARGS = set(exp.Select.arg_types) - { + "expressions", + "from", + "joins", + "where", + "order", +} + + +def merge_ctes(expression, leave_tables_isolated=False): + scopes = traverse_scope(expression) + + # All places where we select from CTEs. + # We key on the CTE scope so we can detect CTES that are selected from multiple times. + cte_selections = defaultdict(list) + for outer_scope in scopes: + for table, inner_scope in outer_scope.selected_sources.values(): + if isinstance(inner_scope, Scope) and inner_scope.is_cte: + cte_selections[id(inner_scope)].append( + ( + outer_scope, + inner_scope, + table, + ) + ) + + singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] + for outer_scope, inner_scope, table in singular_cte_selections: + inner_select = inner_scope.expression.unnest() + if _mergeable(outer_scope, inner_select, leave_tables_isolated): + from_or_join = table.find_ancestor(exp.From, exp.Join) + + node_to_replace = table + if isinstance(node_to_replace.parent, exp.Alias): + node_to_replace = node_to_replace.parent + alias = node_to_replace.alias + else: + alias = table.name + + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, node_to_replace, alias) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_order(outer_scope, inner_scope) + _pop_cte(inner_scope) + + +def merge_derived_tables(expression, leave_tables_isolated=False): + for outer_scope in traverse_scope(expression): + for subquery in outer_scope.derived_tables: + inner_select = subquery.unnest() + if _mergeable(outer_scope, inner_select, leave_tables_isolated): + alias = subquery.alias_or_name + from_or_join = subquery.find_ancestor(exp.From, exp.Join) + inner_scope = outer_scope.sources[alias] + + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, subquery, alias) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_order(outer_scope, inner_scope) + + +def _mergeable(outer_scope, inner_select, leave_tables_isolated): + """ + Return True if `inner_select` can be merged into outer query. + + Args: + outer_scope (Scope) + inner_select (exp.Select) + leave_tables_isolated (bool) + Returns: + bool: True if can be merged + """ + return ( + isinstance(outer_scope.expression, exp.Select) + and isinstance(inner_select, exp.Select) + and isinstance(inner_select, exp.Select) + and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) + and inner_select.args.get("from") + and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) + and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) + ) + + +def _rename_inner_sources(outer_scope, inner_scope, alias): + """ + Renames any sources in the inner query that conflict with names in the outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + alias (str) + """ + taken = set(outer_scope.selected_sources) + conflicts = taken.intersection(set(inner_scope.selected_sources)) + conflicts = conflicts - {alias} + + for conflict in conflicts: + new_name = find_new_name(taken, conflict) + + source, _ = inner_scope.selected_sources[conflict] + new_alias = exp.to_identifier(new_name) + + if isinstance(source, exp.Subquery): + source.set("alias", exp.TableAlias(this=new_alias)) + elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias): + source.parent.set("alias", new_alias) + elif isinstance(source, exp.Table): + source.replace(exp.alias_(source.copy(), new_alias)) + + for column in inner_scope.source_columns(conflict): + column.set("table", exp.to_identifier(new_name)) + + inner_scope.rename_source(conflict, new_name) + + +def _merge_from(outer_scope, inner_scope, node_to_replace, alias): + """ + Merge FROM clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + node_to_replace (exp.Subquery|exp.Table) + alias (str) + """ + new_subquery = inner_scope.expression.args.get("from").expressions[0] + node_to_replace.replace(new_subquery) + outer_scope.remove_source(alias) + outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) + + +def _merge_joins(outer_scope, inner_scope, from_or_join): + """ + Merge JOIN clauses of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + from_or_join (exp.From|exp.Join) + """ + + new_joins = [] + comma_joins = inner_scope.expression.args.get("from").expressions[1:] + for subquery in comma_joins: + new_joins.append(exp.Join(this=subquery, kind="CROSS")) + outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name]) + + joins = inner_scope.expression.args.get("joins") or [] + for join in joins: + new_joins.append(join) + outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name]) + + if new_joins: + outer_joins = outer_scope.expression.args.get("joins", []) + + # Maintain the join order + if isinstance(from_or_join, exp.From): + position = 0 + else: + position = outer_joins.index(from_or_join) + 1 + outer_joins[position:position] = new_joins + + outer_scope.expression.set("joins", outer_joins) + + +def _merge_expressions(outer_scope, inner_scope, alias): + """ + Merge projections of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + alias (str) + """ + # Collect all columns that reference the alias of the inner query + outer_columns = defaultdict(list) + for column in outer_scope.columns: + if column.table == alias: + outer_columns[column.name].append(column) + + # Replace columns with the projection expression in the inner query + for expression in inner_scope.expression.expressions: + projection_name = expression.alias_or_name + if not projection_name: + continue + columns_to_replace = outer_columns.get(projection_name, []) + for column in columns_to_replace: + column.replace(expression.unalias()) + + +def _merge_where(outer_scope, inner_scope, from_or_join): + """ + Merge WHERE clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + from_or_join (exp.From|exp.Join) + """ + where = inner_scope.expression.args.get("where") + if not where or not where.this: + return + + if isinstance(from_or_join, exp.Join): + # Merge predicates from an outer join to the ON clause + from_or_join.on(where.this, copy=False) + from_or_join.set("on", simplify(from_or_join.args.get("on"))) + else: + outer_scope.expression.where(where.this, copy=False) + outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where"))) + + +def _merge_order(outer_scope, inner_scope): + """ + Merge ORDER clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + """ + if ( + any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]) + or len(outer_scope.selected_sources) != 1 + or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions) + ): + return + + outer_scope.expression.set("order", inner_scope.expression.args.get("order")) + + +def _pop_cte(inner_scope): + """ + Remove CTE from the AST. + + Args: + inner_scope (sqlglot.optimizer.scope.Scope) + """ + cte = inner_scope.expression.parent + with_ = cte.parent + if len(with_.expressions) == 1: + with_.pop() + else: + cte.pop() |