Edit on GitHub

sqlglot.optimizer.merge_subqueries

  1from collections import defaultdict
  2
  3from sqlglot import expressions as exp
  4from sqlglot.helper import find_new_name
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6
  7
  8def merge_subqueries(expression, leave_tables_isolated=False):
  9    """
 10    Rewrite sqlglot AST to merge derived tables into the outer query.
 11
 12    This also merges CTEs if they are selected from only once.
 13
 14    Example:
 15        >>> import sqlglot
 16        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
 17        >>> merge_subqueries(expression).sql()
 18        'SELECT x.a FROM x JOIN y'
 19
 20    If `leave_tables_isolated` is True, this will not merge inner queries into outer
 21    queries if it would result in multiple table selects in a single query:
 22        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
 23        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
 24        'SELECT a FROM (SELECT x.a FROM x) JOIN y'
 25
 26    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
 27
 28    Args:
 29        expression (sqlglot.Expression): expression to optimize
 30        leave_tables_isolated (bool):
 31    Returns:
 32        sqlglot.Expression: optimized expression
 33    """
 34    expression = merge_ctes(expression, leave_tables_isolated)
 35    expression = merge_derived_tables(expression, leave_tables_isolated)
 36    return expression
 37
 38
 39# If a derived table has these Select args, it can't be merged
 40UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
 41    "expressions",
 42    "from",
 43    "joins",
 44    "where",
 45    "order",
 46    "hint",
 47}
 48
 49
 50def merge_ctes(expression, leave_tables_isolated=False):
 51    scopes = traverse_scope(expression)
 52
 53    # All places where we select from CTEs.
 54    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 55    cte_selections = defaultdict(list)
 56    for outer_scope in scopes:
 57        for table, inner_scope in outer_scope.selected_sources.values():
 58            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 59                cte_selections[id(inner_scope)].append(
 60                    (
 61                        outer_scope,
 62                        inner_scope,
 63                        table,
 64                    )
 65                )
 66
 67    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 68    for outer_scope, inner_scope, table in singular_cte_selections:
 69        from_or_join = table.find_ancestor(exp.From, exp.Join)
 70        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 71            alias = table.alias_or_name
 72            _rename_inner_sources(outer_scope, inner_scope, alias)
 73            _merge_from(outer_scope, inner_scope, table, alias)
 74            _merge_expressions(outer_scope, inner_scope, alias)
 75            _merge_joins(outer_scope, inner_scope, from_or_join)
 76            _merge_where(outer_scope, inner_scope, from_or_join)
 77            _merge_order(outer_scope, inner_scope)
 78            _merge_hints(outer_scope, inner_scope)
 79            _pop_cte(inner_scope)
 80            outer_scope.clear_cache()
 81    return expression
 82
 83
 84def merge_derived_tables(expression, leave_tables_isolated=False):
 85    for outer_scope in traverse_scope(expression):
 86        for subquery in outer_scope.derived_tables:
 87            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
 88            alias = subquery.alias_or_name
 89            inner_scope = outer_scope.sources[alias]
 90            if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 91                _rename_inner_sources(outer_scope, inner_scope, alias)
 92                _merge_from(outer_scope, inner_scope, subquery, alias)
 93                _merge_expressions(outer_scope, inner_scope, alias)
 94                _merge_joins(outer_scope, inner_scope, from_or_join)
 95                _merge_where(outer_scope, inner_scope, from_or_join)
 96                _merge_order(outer_scope, inner_scope)
 97                _merge_hints(outer_scope, inner_scope)
 98                outer_scope.clear_cache()
 99    return expression
100
101
def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
    """
    Return True if `inner_select` can be merged into outer query.

    `inner_select` is the unnested expression of `inner_scope`. Every condition in
    the final `and` chain must hold for the merge to be allowed.

    Args:
        outer_scope (Scope): scope of the query that selects from the inner query
        inner_scope (Scope): scope of the derived table or CTE being considered
        leave_tables_isolated (bool): if True, refuse when the outer scope selects
            from more than one source
        from_or_join (exp.From|exp.Join): the FROM or JOIN clause of the outer query
            that contains the reference to the inner query
    Returns:
        bool: True if can be merged
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        # True if the outer query references a window-function projection of the
        # inner select from inside WHERE / GROUP / ORDER / JOIN / HAVING or an
        # aggregate. Inlining the window expression there is presumably invalid
        # or changes its meaning, so such merges are refused.
        window_expressions = inner_select.find_all(exp.Window)
        # Names under which the inner select projects its window expressions.
        window_alias_names = {window.parent.alias_or_name for window in window_expressions}
        inner_select_name = inner_select.parent.alias_or_name
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_alias_names
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        # Only relevant when the inner query is the right-hand side of a JOIN.
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.this.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        # Names of inner projections referenced (through the join alias) in ON.
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.expressions[0].alias_or_name
        # NOTE(review): a selection with no matching projection raises KeyError here.
        inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
        # True (= unmergeable) if any referenced projection uses a column that is
        # not from the inner query's first FROM table.
        return any(
            col.table != inner_from_table
            for selection in selections
            for col in inner_projections[selection].find_all(exp.Column)
        )

    return (
        # Only a SELECT merged into a SELECT is supported.
        isinstance(outer_scope.expression, exp.Select)
        and isinstance(inner_select, exp.Select)
        # The inner select may only use the args outside UNMERGABLE_ARGS.
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from")
        # Projections containing aggregates or nested selects are not inlined.
        and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # An inner WHERE is not hoisted when the inner query is outer-joined.
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
        )
        # Nor when the inner query is in FROM and the outer query has a FULL or
        # RIGHT join, which could turn the filtered side's rows into NULLs.
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(
                j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
            )
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
    )
187
188
def _rename_inner_sources(outer_scope, inner_scope, alias):
    """
    Renames any sources in the inner query that conflict with names in the outer query.

    Each replacement name is chosen so that it clashes with no source of either
    query and with no name already assigned during this call. Conflicts are
    processed in sorted order so the generated names are deterministic.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str): name under which the outer query selects the inner query; it
            is replaced by the merge itself, so it is not treated as a conflict.
    """
    inner_taken = set(inner_scope.selected_sources)
    outer_taken = set(outer_scope.selected_sources)
    conflicts = outer_taken.intersection(inner_taken)
    conflicts -= {alias}

    # Avoid every source name of BOTH queries: checking only the outer names could
    # rename an inner source onto another inner source (e.g. `x` -> existing `x_2`).
    taken = outer_taken | inner_taken

    for conflict in sorted(conflicts):
        new_name = find_new_name(taken, conflict)
        # Reserve the name so later conflicts in this loop cannot reuse it.
        taken.add(new_name)

        source, _ = inner_scope.selected_sources[conflict]
        new_alias = exp.to_identifier(new_name)

        if isinstance(source, exp.Subquery):
            source.set("alias", exp.TableAlias(this=new_alias))
        elif isinstance(source, exp.Table) and source.alias:
            source.set("alias", new_alias)
        elif isinstance(source, exp.Table):
            source.replace(exp.alias_(source.copy(), new_alias))

        # Requalify the inner columns that referenced the old name.
        for column in inner_scope.source_columns(conflict):
            column.set("table", exp.to_identifier(new_name))

        inner_scope.rename_source(conflict, new_name)
219
220
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
    """
    Swap the outer query's reference to the inner query for the inner query's first FROM source.

    Outer join hints whose tables named the replaced node are repointed at the new
    source, and the outer scope's source mapping is updated to match.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table): the outer query's reference to the inner query
        alias (str): name under which the outer scope registered `node_to_replace`
    """
    first_source = inner_scope.expression.args.get("from").expressions[0]
    node_to_replace.replace(first_source)

    old_name = node_to_replace.alias_or_name
    new_name = first_source.alias_or_name
    hinted_tables = (
        table for join_hint in outer_scope.join_hints for table in join_hint.find_all(exp.Table)
    )
    for table in hinted_tables:
        if table.alias_or_name == old_name:
            table.set("this", exp.to_identifier(new_name))

    outer_scope.remove_source(alias)
    outer_scope.add_source(new_name, inner_scope.sources[new_name])
242
243
def _merge_joins(outer_scope, inner_scope, from_or_join):
    """
    Merge JOIN clauses of inner query into outer query.

    Extra comma-separated FROM sources of the inner query become CROSS JOINs. All
    new joins are spliced in directly after the clause that held the inner query
    (at the front when it was the FROM), preserving join order. Each new source is
    registered in the outer scope.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join): the outer clause that contained the inner query
    Raises:
        ValueError: if `from_or_join` is a Join that is not one of the outer query's joins.
    """
    new_joins = []

    comma_joins = inner_scope.expression.args.get("from").expressions[1:]
    for subquery in comma_joins:
        new_joins.append(exp.Join(this=subquery, kind="CROSS"))
        outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])

    joins = inner_scope.expression.args.get("joins") or []
    for join in joins:
        new_joins.append(join)
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    if not new_joins:
        return

    # `or []` instead of a `.get` default: the arg may be present but set to None.
    outer_joins = outer_scope.expression.args.get("joins") or []

    # Maintain the join order
    if isinstance(from_or_join, exp.From):
        position = 0
    else:
        # Look the join up by identity: `list.index` would use Expression `==`,
        # which can match a different but structurally equal join.
        for index, join in enumerate(outer_joins):
            if join is from_or_join:
                position = index + 1
                break
        else:
            raise ValueError("join is not one of the outer query's joins")
    outer_joins[position:position] = new_joins

    outer_scope.expression.set("joins", outer_joins)
276
277
def _merge_expressions(outer_scope, inner_scope, alias):
    """
    Substitute inner projections for the outer columns that reference them.

    Every outer column qualified with `alias` whose name matches an inner
    projection's alias/name is replaced by a copy of that projection with its
    alias removed. Projections without a name are skipped.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str): name under which the outer query referenced the inner query
    """
    columns_by_name = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            columns_by_name[column.name].append(column)

    for projection in inner_scope.expression.expressions:
        name = projection.alias_or_name
        if name:
            for column in columns_by_name.get(name, ()):
                column.replace(projection.unalias().copy())
301
302
def _merge_where(outer_scope, inner_scope, from_or_join):
    """
    Merge WHERE clause of inner query into outer query.

    If the inner query was joined, its predicate is attached to that join's ON
    clause, provided every table it references is in scope at that join (the FROM
    sources plus the joins up to and including this one). Otherwise the predicate
    is added to the outer query's WHERE via `Select.where`.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join): the outer clause that contained the inner query
    """
    where = inner_scope.expression.args.get("where")
    if not where or not where.this:
        return

    expression = outer_scope.expression

    if isinstance(from_or_join, exp.Join):
        # Merge predicates from an outer join to the ON clause
        # if it only has columns that are already joined
        from_ = expression.args.get("from")
        # `set()`, not `{}`: an empty `{}` is a dict and has no `.add`.
        sources = {table.alias_or_name for table in from_.expressions} if from_ else set()

        for join in expression.args["joins"]:
            source = join.alias_or_name
            sources.add(source)
            if source == from_or_join.alias_or_name:
                break

        if set(exp.column_table_names(where.this)) <= sources:
            from_or_join.on(where.this, copy=False)
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    expression.where(where.this, copy=False)
    expression.set("where", expression.args.get("where"))
337
338
def _merge_order(outer_scope, inner_scope):
    """
    Copy the inner query's ORDER onto the outer query when that is safe.

    Nothing happens if the outer query has GROUP, DISTINCT, HAVING or its own
    ORDER, selects from anything other than exactly one source, or has an
    aggregate in its projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    outer = outer_scope.expression
    if any(outer.args.get(key) for key in ("group", "distinct", "having", "order")):
        return
    if len(outer_scope.selected_sources) != 1:
        return
    if any(projection.find(exp.AggFunc) for projection in outer.expressions):
        return

    outer.set("order", inner_scope.expression.args.get("order"))
357
358
def _merge_hints(outer_scope, inner_scope):
    """
    Carry the inner query's hint over to the outer query.

    If the outer query has no hint, the inner hint node is set on it. Otherwise
    each inner hint expression is appended to the outer hint.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    inner_hint = inner_scope.expression.args.get("hint")
    if not inner_hint:
        return

    outer_hint = outer_scope.expression.args.get("hint")
    if not outer_hint:
        outer_scope.expression.set("hint", inner_hint)
        return

    for hint_expression in inner_hint.expressions:
        outer_hint.append("expressions", hint_expression)
369
370
def _pop_cte(inner_scope):
    """
    Detach the CTE owning `inner_scope` from the AST.

    When it is the only CTE in its WITH clause, the whole WITH node is removed;
    otherwise only this CTE is.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_clause = cte.parent
    node_to_remove = with_clause if len(with_clause.expressions) == 1 else cte
    node_to_remove.pop()
def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Flatten derived tables (and single-use CTEs) into the queries that select from them.

    CTEs are handled first, then derived tables.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip a merge whenever the outer query
            already selects from more than one source.
    Returns:
        sqlglot.Expression: optimized expression
    """
    return merge_derived_tables(
        merge_ctes(expression, leave_tables_isolated),
        leave_tables_isolated,
    )

Rewrite sqlglot AST to merge derived tables into the outer query.

This also merges CTEs if they are selected from only once.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x JOIN y'

If leave_tables_isolated is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:

>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) JOIN y'

Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

Arguments:
  • expression (sqlglot.Expression): expression to optimize
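  • leave_tables_isolated (bool): if True, do not merge an inner query when the outer query already selects from more than one source.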
Returns:

sqlglot.Expression: optimized expression

def merge_ctes(expression, leave_tables_isolated=False):
    """
    Inline every CTE that is selected from exactly once into the query selecting it.

    CTEs referenced more than once are left alone. Each single-use CTE must also
    pass `_mergeable`; once merged, the CTE is removed from its WITH clause.

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
        leave_tables_isolated (bool): passed through to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    # Group every selection of a CTE by the identity of the CTE's scope, so a
    # CTE referenced from several places can be recognised and skipped.
    selections_by_cte = defaultdict(list)
    for outer in traverse_scope(expression):
        for source_node, source_scope in outer.selected_sources.values():
            if not (isinstance(source_scope, Scope) and source_scope.is_cte):
                continue
            selections_by_cte[id(source_scope)].append((outer, source_scope, source_node))

    single_use = [found[0] for found in selections_by_cte.values() if len(found) == 1]

    for outer, cte_scope, table in single_use:
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if not _mergeable(outer, cte_scope, leave_tables_isolated, from_or_join):
            continue
        alias = table.alias_or_name
        _rename_inner_sources(outer, cte_scope, alias)
        _merge_from(outer, cte_scope, table, alias)
        _merge_expressions(outer, cte_scope, alias)
        _merge_joins(outer, cte_scope, from_or_join)
        _merge_where(outer, cte_scope, from_or_join)
        _merge_order(outer, cte_scope)
        _merge_hints(outer, cte_scope)
        _pop_cte(cte_scope)
        outer.clear_cache()
    return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Inline derived tables (subqueries in FROM/JOIN) into their enclosing queries.

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
        leave_tables_isolated (bool): passed through to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    for outer in traverse_scope(expression):
        for derived in outer.derived_tables:
            alias = derived.alias_or_name
            from_or_join = derived.find_ancestor(exp.From, exp.Join)
            inner = outer.sources[alias]
            if not _mergeable(outer, inner, leave_tables_isolated, from_or_join):
                continue
            _rename_inner_sources(outer, inner, alias)
            _merge_from(outer, inner, derived, alias)
            _merge_expressions(outer, inner, alias)
            _merge_joins(outer, inner, from_or_join)
            _merge_where(outer, inner, from_or_join)
            _merge_order(outer, inner)
            _merge_hints(outer, inner)
            outer.clear_cache()
    return expression