Edit on GitHub

sqlglot.optimizer.merge_subqueries

  1from collections import defaultdict
  2
  3from sqlglot import expressions as exp
  4from sqlglot.helper import find_new_name
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6from sqlglot.optimizer.simplify import simplify
  7
  8
  9def merge_subqueries(expression, leave_tables_isolated=False):
 10    """
 11    Rewrite sqlglot AST to merge derived tables into the outer query.
 12
 13    This also merges CTEs if they are selected from only once.
 14
 15    Example:
 16        >>> import sqlglot
 17        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
 18        >>> merge_subqueries(expression).sql()
 19        'SELECT x.a FROM x JOIN y'
 20
 21    If `leave_tables_isolated` is True, this will not merge inner queries into outer
 22    queries if it would result in multiple table selects in a single query:
 23        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
 24        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
 25        'SELECT a FROM (SELECT x.a FROM x) JOIN y'
 26
 27    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
 28
 29    Args:
 30        expression (sqlglot.Expression): expression to optimize
 31        leave_tables_isolated (bool):
 32    Returns:
 33        sqlglot.Expression: optimized expression
 34    """
 35    expression = merge_ctes(expression, leave_tables_isolated)
 36    expression = merge_derived_tables(expression, leave_tables_isolated)
 37    return expression
 38
 39
 40# If a derived table has these Select args, it can't be merged
 41UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
 42    "expressions",
 43    "from",
 44    "joins",
 45    "where",
 46    "order",
 47    "hint",
 48}
 49
 50
 51def merge_ctes(expression, leave_tables_isolated=False):
 52    scopes = traverse_scope(expression)
 53
 54    # All places where we select from CTEs.
 55    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 56    cte_selections = defaultdict(list)
 57    for outer_scope in scopes:
 58        for table, inner_scope in outer_scope.selected_sources.values():
 59            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 60                cte_selections[id(inner_scope)].append(
 61                    (
 62                        outer_scope,
 63                        inner_scope,
 64                        table,
 65                    )
 66                )
 67
 68    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 69    for outer_scope, inner_scope, table in singular_cte_selections:
 70        from_or_join = table.find_ancestor(exp.From, exp.Join)
 71        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 72            alias = table.alias_or_name
 73            _rename_inner_sources(outer_scope, inner_scope, alias)
 74            _merge_from(outer_scope, inner_scope, table, alias)
 75            _merge_expressions(outer_scope, inner_scope, alias)
 76            _merge_joins(outer_scope, inner_scope, from_or_join)
 77            _merge_where(outer_scope, inner_scope, from_or_join)
 78            _merge_order(outer_scope, inner_scope)
 79            _merge_hints(outer_scope, inner_scope)
 80            _pop_cte(inner_scope)
 81            outer_scope.clear_cache()
 82    return expression
 83
 84
 85def merge_derived_tables(expression, leave_tables_isolated=False):
 86    for outer_scope in traverse_scope(expression):
 87        for subquery in outer_scope.derived_tables:
 88            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
 89            alias = subquery.alias_or_name
 90            inner_scope = outer_scope.sources[alias]
 91            if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 92                _rename_inner_sources(outer_scope, inner_scope, alias)
 93                _merge_from(outer_scope, inner_scope, subquery, alias)
 94                _merge_expressions(outer_scope, inner_scope, alias)
 95                _merge_joins(outer_scope, inner_scope, from_or_join)
 96                _merge_where(outer_scope, inner_scope, from_or_join)
 97                _merge_order(outer_scope, inner_scope)
 98                _merge_hints(outer_scope, inner_scope)
 99                outer_scope.clear_cache()
100    return expression
101
102
def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
    """
    Return True if the query of `inner_scope` can be merged into the outer query.

    Args:
        outer_scope (Scope)
        inner_scope (Scope)
        leave_tables_isolated (bool)
        from_or_join (exp.From|exp.Join)
    Returns:
        bool: True if can be merged
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        """
        Return True if the outer query references a window projection of the inner
        select inside WHERE, GROUP BY, ORDER BY, JOIN, HAVING or an aggregate function.
        """
        window_expressions = inner_select.find_all(exp.Window)
        window_alias_names = {window.parent.alias_or_name for window in window_expressions}
        inner_select_name = inner_select.parent.alias_or_name
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_alias_names
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.this.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.expressions[0].alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
        return any(
            col.table != inner_from_table
            for selection in selections
            for col in inner_projections[selection].find_all(exp.Column)
        )

    # bool(...) keeps the documented return type: a bare `and` chain would hand back
    # a falsy operand such as None when the inner select has no FROM clause.
    return bool(
        # Both sides have to be plain SELECTs
        isinstance(outer_scope.expression, exp.Select)
        and isinstance(inner_select, exp.Select)
        # The inner select may only use clauses that can be merged
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from")
        # No aggregate or nested select inside the inner projections
        and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # An inner WHERE can't be merged when the inner query is FULL/LEFT/RIGHT joined
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
        )
        # ... nor when the inner query is in FROM and the outer query has a FULL/RIGHT join
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(
                j.side in {"FULL", "RIGHT"}
                # `or []` also covers a "joins" key that is present but set to None
                for j in (outer_scope.expression.args.get("joins") or [])
            )
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
    )
188
189
def _rename_inner_sources(outer_scope, inner_scope, alias):
    """
    Renames any sources in the inner query that conflict with names in the outer query.

    The source named `alias` (the inner query itself, as seen from the outer query)
    is never treated as a conflict.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    inner_taken = set(inner_scope.selected_sources)
    outer_taken = set(outer_scope.selected_sources)
    conflicts = outer_taken.intersection(inner_taken)
    conflicts -= {alias}

    # A replacement name must be unused in BOTH queries: after the merge the inner
    # sources live next to the outer ones, so clashing with another inner source
    # would be just as ambiguous as clashing with an outer one.
    taken = outer_taken | inner_taken

    for conflict in conflicts:
        new_name = find_new_name(taken, conflict)
        # Reserve the name so a later conflict can't be given the same one
        taken.add(new_name)

        source, _ = inner_scope.selected_sources[conflict]
        new_alias = exp.to_identifier(new_name)

        if isinstance(source, exp.Subquery):
            source.set("alias", exp.TableAlias(this=new_alias))
        elif isinstance(source, exp.Table) and source.alias:
            source.set("alias", new_alias)
        elif isinstance(source, exp.Table):
            # Unaliased table: wrap a copy in an alias carrying the new name
            source.replace(exp.alias_(source.copy(), new_alias))

        # Point every inner column that referenced the old name at the new one
        for column in inner_scope.source_columns(conflict):
            column.set("table", exp.to_identifier(new_name))

        inner_scope.rename_source(conflict, new_name)
220
221
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
    """
    Replace the inner query in the outer query with the first source of its FROM clause.

    Join hints of the outer query that name the replaced node are re-pointed at the
    new source, and the outer scope's sources are updated to match.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table)
        alias (str)
    """
    inner_from = inner_scope.expression.args.get("from")
    first_source = inner_from.expressions[0]
    node_to_replace.replace(first_source)

    for join_hint in outer_scope.join_hints:
        for table in join_hint.find_all(exp.Table):
            if table.alias_or_name == node_to_replace.alias_or_name:
                table.set("this", exp.to_identifier(first_source.alias_or_name))

    # Swap the merged query for its first source in the outer scope's bookkeeping
    outer_scope.remove_source(alias)
    source_name = first_source.alias_or_name
    outer_scope.add_source(source_name, inner_scope.sources[source_name])
243
244
def _merge_joins(outer_scope, inner_scope, from_or_join):
    """
    Merge JOIN clauses of inner query into outer query.

    Sources after the first one in the inner FROM clause (comma joins) are turned
    into CROSS joins; explicit inner joins are carried over as they are. Each moved
    source is also registered on the outer scope.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """

    new_joins = []
    comma_joins = inner_scope.expression.args.get("from").expressions[1:]
    for subquery in comma_joins:
        new_joins.append(exp.Join(this=subquery, kind="CROSS"))
        outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])

    joins = inner_scope.expression.args.get("joins") or []
    for join in joins:
        new_joins.append(join)
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    if new_joins:
        # `or []` (rather than a .get default) also covers a "joins" key that is
        # present but None, which would otherwise break the slice assignment below.
        outer_joins = outer_scope.expression.args.get("joins") or []

        # Maintain the join order: the new joins go right after the place the
        # inner query occupied (first if it was in FROM).
        if isinstance(from_or_join, exp.From):
            position = 0
        else:
            position = outer_joins.index(from_or_join) + 1
        outer_joins[position:position] = new_joins

        outer_scope.expression.set("joins", outer_joins)
277
278
def _merge_expressions(outer_scope, inner_scope, alias):
    """
    Inline the inner query's projections wherever the outer query references them.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    # Outer columns qualified with the inner query's alias, grouped by column name
    referencing = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            referencing[column.name].append(column)

    # Swap each such column for a copy of the matching (unaliased) inner projection
    for projection in inner_scope.expression.expressions:
        name = projection.alias_or_name
        if name:
            for column in referencing.get(name, []):
                column.replace(projection.unalias().copy())
302
303
def _merge_where(outer_scope, inner_scope, from_or_join):
    """
    Move the inner query's WHERE predicate into the outer query.

    When the inner query was joined, the predicate is added to that join's ON
    clause; otherwise it is added to the outer WHERE clause. The combined
    condition is simplified afterwards.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    where = inner_scope.expression.args.get("where")
    if not where or not where.this:
        return

    if isinstance(from_or_join, exp.Join):
        # Merge predicates from an outer join to the ON clause
        from_or_join.on(where.this, copy=False)
        simplified_on = simplify(from_or_join.args.get("on"))
        from_or_join.set("on", simplified_on)
        return

    outer_select = outer_scope.expression
    outer_select.where(where.this, copy=False)
    simplified_where = simplify(outer_select.args.get("where"))
    outer_select.set("where", simplified_where)
324
325
def _merge_order(outer_scope, inner_scope):
    """
    Carry the inner query's ORDER clause over to the outer query.

    Nothing happens if the outer query has GROUP, DISTINCT, HAVING or ORDER set,
    doesn't select from exactly one source, or has an aggregate function in its
    projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    outer_select = outer_scope.expression

    if any(outer_select.args.get(arg) for arg in ["group", "distinct", "having", "order"]):
        return
    if len(outer_scope.selected_sources) != 1:
        return
    if any(projection.find(exp.AggFunc) for projection in outer_select.expressions):
        return

    outer_select.set("order", inner_scope.expression.args.get("order"))
344
345
def _merge_hints(outer_scope, inner_scope):
    """
    Merge the hint of the inner query into the outer query.

    If the outer query already has a hint, the inner hint's expressions are
    appended to it; otherwise the inner hint becomes the outer query's hint.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    inner_hint = inner_scope.expression.args.get("hint")
    if not inner_hint:
        return

    outer_hint = outer_scope.expression.args.get("hint")
    if not outer_hint:
        outer_scope.expression.set("hint", inner_hint)
        return

    for hint_expression in inner_hint.expressions:
        outer_hint.append("expressions", hint_expression)
356
357
def _pop_cte(inner_scope):
    """
    Remove the CTE of `inner_scope` from the AST.

    If it is the only expression of its WITH clause, the whole WITH clause is
    removed instead.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent
    is_only_cte = len(with_.expressions) == 1
    node_to_remove = with_ if is_only_cte else cte
    node_to_remove.pop()
def merge_subqueries(expression, leave_tables_isolated=False):
def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Merge derived tables and CTEs into the queries that select from them.

    CTEs are handled first (only those selected from a single time are merged),
    then derived tables.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x JOIN y'

    With `leave_tables_isolated` set to True, an inner query is left alone whenever
    merging it would put multiple table selects into a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip merges that would result in
            multiple table selects in a single query
    Returns:
        sqlglot.Expression: optimized expression
    """
    without_ctes = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(without_ctes, leave_tables_isolated)

Rewrite sqlglot AST to merge derived tables into the outer query.

This also merges CTEs if they are selected from only once.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x JOIN y'

If leave_tables_isolated is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:

>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) JOIN y'

Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

Arguments:
  • expression (sqlglot.Expression): expression to optimize
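  • leave_tables_isolated (bool): if True, do not merge inner queries into outer queries when that would result in multiple table selects in a single query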
Returns:

sqlglot.Expression: optimized expression

def merge_ctes(expression, leave_tables_isolated=False):
def merge_ctes(expression, leave_tables_isolated=False):
    """
    Merge every CTE that is selected from exactly once into the query selecting from it.

    A merged CTE is also removed from its WITH clause.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to the mergeability check
    Returns:
        sqlglot.Expression: the same expression, modified in place
    """
    # Record each place a CTE is selected from, grouped by the identity of the
    # CTE's scope, so CTEs that are selected from several times can be detected.
    selections_by_cte = defaultdict(list)
    for outer_scope in traverse_scope(expression):
        for table, inner_scope in outer_scope.selected_sources.values():
            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
                selections_by_cte[id(inner_scope)].append((outer_scope, inner_scope, table))

    single_use = [group[0] for group in selections_by_cte.values() if len(group) == 1]

    for outer_scope, inner_scope, table in single_use:
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            continue

        alias = table.alias_or_name
        _rename_inner_sources(outer_scope, inner_scope, alias)
        _merge_from(outer_scope, inner_scope, table, alias)
        _merge_expressions(outer_scope, inner_scope, alias)
        _merge_joins(outer_scope, inner_scope, from_or_join)
        _merge_where(outer_scope, inner_scope, from_or_join)
        _merge_order(outer_scope, inner_scope)
        _merge_hints(outer_scope, inner_scope)
        _pop_cte(inner_scope)
        outer_scope.clear_cache()

    return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
 86def merge_derived_tables(expression, leave_tables_isolated=False):
 87    for outer_scope in traverse_scope(expression):
 88        for subquery in outer_scope.derived_tables:
 89            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
 90            alias = subquery.alias_or_name
 91            inner_scope = outer_scope.sources[alias]
 92            if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 93                _rename_inner_sources(outer_scope, inner_scope, alias)
 94                _merge_from(outer_scope, inner_scope, subquery, alias)
 95                _merge_expressions(outer_scope, inner_scope, alias)
 96                _merge_joins(outer_scope, inner_scope, from_or_join)
 97                _merge_where(outer_scope, inner_scope, from_or_join)
 98                _merge_order(outer_scope, inner_scope)
 99                _merge_hints(outer_scope, inner_scope)
100                outer_scope.clear_cache()
101    return expression