summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/merge_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/merge_subqueries.py')
-rw-r--r--sqlglot/optimizer/merge_subqueries.py287
1 files changed, 287 insertions, 0 deletions
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
new file mode 100644
index 0000000..9d966b7
--- /dev/null
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -0,0 +1,287 @@
+from collections import defaultdict
+
+from sqlglot import expressions as exp
+from sqlglot.helper import find_new_name
+from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Rewrite sqlglot AST to merge derived tables into the outer query.

    CTEs are merged as well, as long as they are selected from only once.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip any merge whose outer query
            selects from more than one source
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    # CTEs are handled first, then derived tables; both passes receive the same flag.
    for merge_pass in (merge_ctes, merge_derived_tables):
        merge_pass(expression, leave_tables_isolated)
    return expression
+
+
# A derived table can only be merged when the sole Select args it uses are the
# ones listed below; the presence of any other Select arg blocks the merge.
UNMERGABLE_ARGS = set(exp.Select.arg_types).difference(
    ("expressions", "from", "joins", "where", "order")
)
+
+
def merge_ctes(expression, leave_tables_isolated=False):
    """
    Merge CTEs that are selected from exactly once into the selecting query.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to the mergeability check
    """
    # Record every place a CTE is selected from, grouped by the identity of the
    # CTE's scope, so CTEs with more than one selection can be told apart.
    selections_by_cte = defaultdict(list)
    for outer_scope in traverse_scope(expression):
        for table, source in outer_scope.selected_sources.values():
            if not (isinstance(source, Scope) and source.is_cte):
                continue
            selections_by_cte[id(source)].append((outer_scope, source, table))

    for selections in selections_by_cte.values():
        # Only CTEs with a single selection are candidates for merging
        if len(selections) != 1:
            continue
        outer_scope, inner_scope, table = selections[0]

        inner_select = inner_scope.expression.unnest()
        if not _mergeable(outer_scope, inner_select, leave_tables_isolated):
            continue

        from_or_join = table.find_ancestor(exp.From, exp.Join)

        # When the table reference is wrapped in an Alias, the Alias node is
        # what gets swapped out and its alias is the name used by the outer query.
        parent = table.parent
        if isinstance(parent, exp.Alias):
            node_to_replace = parent
            alias = parent.alias
        else:
            node_to_replace = table
            alias = table.name

        _rename_inner_sources(outer_scope, inner_scope, alias)
        _merge_from(outer_scope, inner_scope, node_to_replace, alias)
        _merge_joins(outer_scope, inner_scope, from_or_join)
        _merge_expressions(outer_scope, inner_scope, alias)
        _merge_where(outer_scope, inner_scope, from_or_join)
        _merge_order(outer_scope, inner_scope)
        _pop_cte(inner_scope)
+
+
def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Merge every mergeable derived table into the query that selects from it.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to the mergeability check
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            if not _mergeable(outer_scope, subquery.unnest(), leave_tables_isolated):
                continue

            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, subquery, alias)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_order(outer_scope, inner_scope)
+
+
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
    """
    Return True if `inner_select` can be merged into outer query.

    Args:
        outer_scope (Scope)
        inner_select (exp.Select)
        leave_tables_isolated (bool)
    Returns:
        bool: True if can be merged
    """
    # Both sides of the merge must be plain SELECTs
    if not isinstance(outer_scope.expression, exp.Select):
        return False
    if not isinstance(inner_select, exp.Select):
        return False

    # The inner query may only use the Select args that are not in UNMERGABLE_ARGS
    if any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS):
        return False

    # Without a FROM clause there is no source to splice into the outer query
    if not inner_select.args.get("from"):
        return False

    # Projections containing an aggregate function or a nested SELECT are not merged
    if any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions):
        return False

    # Optionally refuse to merge when the outer query selects from several sources
    if leave_tables_isolated and len(outer_scope.selected_sources) > 1:
        return False

    return True
+
+
def _rename_inner_sources(outer_scope, inner_scope, alias):
    """
    Renames any sources in the inner query that conflict with names in the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str): name the inner query goes by in the outer query; it is
            never treated as a conflict
    """
    outer_names = set(outer_scope.selected_sources)
    inner_names = set(inner_scope.selected_sources)
    conflicts = (outer_names & inner_names) - {alias}

    # A replacement name must be unused in BOTH queries: checking only the outer
    # names could hand out a name that another inner source already uses.
    taken = outer_names | inner_names

    # Sorted so that the generated names do not depend on set iteration order
    for conflict in sorted(conflicts):
        new_name = find_new_name(taken, conflict)
        # Reserve the name so a later conflict cannot be given the same one
        taken.add(new_name)

        source, _ = inner_scope.selected_sources[conflict]
        new_alias = exp.to_identifier(new_name)

        if isinstance(source, exp.Subquery):
            source.set("alias", exp.TableAlias(this=new_alias))
        elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
            source.parent.set("alias", new_alias)
        elif isinstance(source, exp.Table):
            # Unaliased table: wrap a copy in an alias carrying the new name
            source.replace(exp.alias_(source.copy(), new_alias))

        # Re-point every inner column that referenced the old source name
        for column in inner_scope.source_columns(conflict):
            column.set("table", exp.to_identifier(new_name))

        inner_scope.rename_source(conflict, new_name)
+
+
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
    """
    Merge FROM clause of inner query into outer query.

    The first source of the inner FROM clause takes the place of
    `node_to_replace`, and the outer scope's sources are updated to match.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table)
        alias (str)
    """
    inner_from = inner_scope.expression.args.get("from")
    first_source = inner_from.expressions[0]
    node_to_replace.replace(first_source)

    # Swap the bookkeeping entry: the merged query's alias goes away and the
    # source that replaced it is registered under its own name.
    outer_scope.remove_source(alias)
    source_name = first_source.alias_or_name
    outer_scope.add_source(source_name, inner_scope.sources[source_name])
+
+
def _merge_joins(outer_scope, inner_scope, from_or_join):
    """
    Merge JOIN clauses of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join): clause of the outer query that held
            the inner query before it was merged
    """
    inner_select = inner_scope.expression
    new_joins = []

    # Every source after the first in the inner FROM clause becomes a CROSS join
    for source in inner_select.args.get("from").expressions[1:]:
        new_joins.append(exp.Join(this=source, kind="CROSS"))
        outer_scope.add_source(source.alias_or_name, inner_scope.sources[source.alias_or_name])

    # The inner query's own joins are carried over as they are
    for join in inner_select.args.get("joins") or []:
        new_joins.append(join)
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    if not new_joins:
        return

    # `or []` (rather than a .get default) also covers a "joins" key that is
    # present but holds None, which would otherwise break the slice assignment.
    outer_joins = outer_scope.expression.args.get("joins") or []

    # Maintain the join order: joins of a query merged into FROM come first,
    # joins of a query merged into a JOIN go right after that JOIN.
    if isinstance(from_or_join, exp.From):
        position = 0
    else:
        position = outer_joins.index(from_or_join) + 1
    outer_joins[position:position] = new_joins

    outer_scope.expression.set("joins", outer_joins)
+
+
def _merge_expressions(outer_scope, inner_scope, alias):
    """
    Merge projections of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for projection in inner_scope.expression.expressions:
        projection_name = projection.alias_or_name
        if not projection_name:
            continue

        replacement = projection.unalias()
        for column in outer_columns.get(projection_name, []):
            # Each reference gets its own copy: splicing the very same node
            # into several places would leave one node shared by many parents.
            column.replace(replacement.copy())
+
+
def _merge_where(outer_scope, inner_scope, from_or_join):
    """
    Merge WHERE clause of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    where = inner_scope.expression.args.get("where")
    predicate = where.this if where else None
    if not predicate:
        return

    if isinstance(from_or_join, exp.Join):
        # Merge predicates from an outer join to the ON clause
        target, clause = from_or_join, "on"
        target.on(predicate, copy=False)
    else:
        # Otherwise the predicate joins the outer query's WHERE clause
        target, clause = outer_scope.expression, "where"
        target.where(predicate, copy=False)

    # Simplify the combined condition in place
    target.set(clause, simplify(target.args.get(clause)))
+
+
def _merge_order(outer_scope, inner_scope):
    """
    Merge ORDER clause of inner query into outer query.

    The inner ordering is only carried over when the outer query has no
    group/distinct/having/order of its own, selects from exactly one source
    and has no aggregate function among its projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    outer_select = outer_scope.expression

    if any(outer_select.args.get(arg) for arg in ["group", "distinct", "having", "order"]):
        return
    if len(outer_scope.selected_sources) != 1:
        return
    if any(projection.find(exp.AggFunc) for projection in outer_select.expressions):
        return

    outer_select.set("order", inner_scope.expression.args.get("order"))
+
+
def _pop_cte(inner_scope):
    """
    Remove CTE from the AST.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent
    # If this is the only expression under its parent, drop the parent as a
    # whole; otherwise remove just this CTE.
    target = with_ if len(with_.expressions) == 1 else cte
    target.pop()