diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py (renamed from sqlglot/optimizer/merge_derived_tables.py) | 149 |
1 files changed, 102 insertions, 47 deletions
diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_subqueries.py index 8b161fb..9d966b7 100644 --- a/sqlglot/optimizer/merge_derived_tables.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -1,72 +1,127 @@ from collections import defaultdict from sqlglot import expressions as exp -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.helper import find_new_name +from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.optimizer.simplify import simplify -def merge_derived_tables(expression): +def merge_subqueries(expression, leave_tables_isolated=False): """ Rewrite sqlglot AST to merge derived tables into the outer query. + This also merges CTEs if they are selected from only once. + Example: >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)") - >>> merge_derived_tables(expression).sql() - 'SELECT x.a FROM x' + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> merge_subqueries(expression).sql() + 'SELECT x.a FROM x JOIN y' + + If `leave_tables_isolated` is True, this will not merge inner queries into outer + queries if it would result in multiple table selects in a single query: + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> merge_subqueries(expression, leave_tables_isolated=True).sql() + 'SELECT a FROM (SELECT x.a FROM x) JOIN y' Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html Args: expression (sqlglot.Expression): expression to optimize + leave_tables_isolated (bool): Returns: sqlglot.Expression: optimized expression """ + merge_ctes(expression, leave_tables_isolated) + merge_derived_tables(expression, leave_tables_isolated) + return expression + + +# If a derived table has these Select args, it can't be merged +UNMERGABLE_ARGS = set(exp.Select.arg_types) - { + "expressions", + "from", + "joins", + "where", + "order", +} + + +def merge_ctes(expression, leave_tables_isolated=False): + scopes = traverse_scope(expression) + + # All places where we select from CTEs. + # We key on the CTE scope so we can detect CTES that are selected from multiple times. + cte_selections = defaultdict(list) + for outer_scope in scopes: + for table, inner_scope in outer_scope.selected_sources.values(): + if isinstance(inner_scope, Scope) and inner_scope.is_cte: + cte_selections[id(inner_scope)].append( + ( + outer_scope, + inner_scope, + table, + ) + ) + + singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] + for outer_scope, inner_scope, table in singular_cte_selections: + inner_select = inner_scope.expression.unnest() + if _mergeable(outer_scope, inner_select, leave_tables_isolated): + from_or_join = table.find_ancestor(exp.From, exp.Join) + + node_to_replace = table + if isinstance(node_to_replace.parent, exp.Alias): + node_to_replace = node_to_replace.parent + alias = node_to_replace.alias + else: + alias = table.name + + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, node_to_replace, alias) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_order(outer_scope, inner_scope) + _pop_cte(inner_scope) + + +def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: inner_select = subquery.unnest() - if ( - isinstance(outer_scope.expression, exp.Select) - and isinstance(inner_select, exp.Select) - and _mergeable(inner_select) - ): + if _mergeable(outer_scope, inner_select, leave_tables_isolated): alias = subquery.alias_or_name from_or_join = subquery.find_ancestor(exp.From, exp.Join) inner_scope = outer_scope.sources[alias] _rename_inner_sources(outer_scope, inner_scope, alias) - _merge_from(outer_scope, inner_scope, subquery) + _merge_from(outer_scope, inner_scope, subquery, alias) _merge_joins(outer_scope, inner_scope, from_or_join) _merge_expressions(outer_scope, inner_scope, alias) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) - return expression -# If a derived table has these Select args, it can't be merged -UNMERGABLE_ARGS = set(exp.Select.arg_types) - { - "expressions", - "from", - "joins", - "where", - "order", -} - - -def _mergeable(inner_select): +def _mergeable(outer_scope, inner_select, leave_tables_isolated): """ Return True if `inner_select` can be merged into outer query. Args: + outer_scope (Scope) inner_select (exp.Select) + leave_tables_isolated (bool) Returns: bool: True if can be merged """ return ( - isinstance(inner_select, exp.Select) + isinstance(outer_scope.expression, exp.Select) + and isinstance(inner_select, exp.Select) + and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) + and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) ) @@ -84,7 +139,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): conflicts = conflicts - {alias} for conflict in conflicts: - new_name = _find_new_name(taken, conflict) + new_name = find_new_name(taken, conflict) source, _ = inner_scope.selected_sources[conflict] new_alias = exp.to_identifier(new_name) @@ -102,34 +157,19 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): inner_scope.rename_source(conflict, new_name) -def _find_new_name(taken, base): - """ - Searches for a new source name. - - Args: - taken (set[str]): set of taken names - base (str): base name to alter - """ - i = 2 - new = f"{base}_{i}" - while new in taken: - i += 1 - new = f"{base}_{i}" - return new - - -def _merge_from(outer_scope, inner_scope, subquery): +def _merge_from(outer_scope, inner_scope, node_to_replace, alias): """ Merge FROM clause of inner query into outer query. Args: outer_scope (sqlglot.optimizer.scope.Scope) inner_scope (sqlglot.optimizer.scope.Scope) - subquery (exp.Subquery) + node_to_replace (exp.Subquery|exp.Table) + alias (str) """ new_subquery = inner_scope.expression.args.get("from").expressions[0] - subquery.replace(new_subquery) - outer_scope.remove_source(subquery.alias_or_name) + node_to_replace.replace(new_subquery) + outer_scope.remove_source(alias) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) @@ -176,7 +216,7 @@ def _merge_expressions(outer_scope, inner_scope, alias): inner_scope (sqlglot.optimizer.scope.Scope) alias (str) """ - # Collect all columns that for the alias of the inner query + # Collect all columns that reference the alias of the inner query outer_columns = defaultdict(list) for column in outer_scope.columns: if column.table == alias: @@ -205,7 +245,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join): if not where or not where.this: return - if isinstance(from_or_join, exp.Join) and from_or_join.side: + if isinstance(from_or_join, exp.Join): # Merge predicates from an outer join to the ON clause from_or_join.on(where.this, copy=False) from_or_join.set("on", simplify(from_or_join.args.get("on"))) @@ -230,3 +270,18 @@ def _merge_order(outer_scope, inner_scope): return outer_scope.expression.set("order", inner_scope.expression.args.get("order")) + + +def _pop_cte(inner_scope): + """ + Remove CTE from the AST. + + Args: + inner_scope (sqlglot.optimizer.scope.Scope) + """ + cte = inner_scope.expression.parent + with_ = cte.parent + if len(with_.expressions) == 1: + with_.pop() + else: + cte.pop() |