Edit on GitHub

sqlglot.optimizer.merge_subqueries

  1from collections import defaultdict
  2
  3from sqlglot import expressions as exp
  4from sqlglot.helper import find_new_name
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6
  7
  8def merge_subqueries(expression, leave_tables_isolated=False):
  9    """
 10    Rewrite sqlglot AST to merge derived tables into the outer query.
 11
 12    This also merges CTEs if they are selected from only once.
 13
 14    Example:
 15        >>> import sqlglot
 16        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 17        >>> merge_subqueries(expression).sql()
 18        'SELECT x.a FROM x CROSS JOIN y'
 19
 20    If `leave_tables_isolated` is True, this will not merge inner queries into outer
 21    queries if it would result in multiple table selects in a single query:
 22        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 23        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
 24        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
 25
 26    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
 27
 28    Args:
 29        expression (sqlglot.Expression): expression to optimize
 30        leave_tables_isolated (bool):
 31    Returns:
 32        sqlglot.Expression: optimized expression
 33    """
 34    expression = merge_ctes(expression, leave_tables_isolated)
 35    expression = merge_derived_tables(expression, leave_tables_isolated)
 36    return expression
 37
 38
 39# If a derived table has these Select args, it can't be merged
 40UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
 41    "expressions",
 42    "from",
 43    "joins",
 44    "where",
 45    "order",
 46    "hint",
 47}
 48
 49
 50# Projections in the outer query that are instances of these types can be replaced
 51# without getting wrapped in parentheses, because the precedence won't be altered.
 52SAFE_TO_REPLACE_UNWRAPPED = (
 53    exp.Column,
 54    exp.EQ,
 55    exp.Func,
 56    exp.NEQ,
 57    exp.Paren,
 58)
 59
 60
 61def merge_ctes(expression, leave_tables_isolated=False):
 62    scopes = traverse_scope(expression)
 63
 64    # All places where we select from CTEs.
 65    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 66    cte_selections = defaultdict(list)
 67    for outer_scope in scopes:
 68        for table, inner_scope in outer_scope.selected_sources.values():
 69            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 70                cte_selections[id(inner_scope)].append(
 71                    (
 72                        outer_scope,
 73                        inner_scope,
 74                        table,
 75                    )
 76                )
 77
 78    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 79    for outer_scope, inner_scope, table in singular_cte_selections:
 80        from_or_join = table.find_ancestor(exp.From, exp.Join)
 81        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 82            alias = table.alias_or_name
 83            _rename_inner_sources(outer_scope, inner_scope, alias)
 84            _merge_from(outer_scope, inner_scope, table, alias)
 85            _merge_expressions(outer_scope, inner_scope, alias)
 86            _merge_joins(outer_scope, inner_scope, from_or_join)
 87            _merge_where(outer_scope, inner_scope, from_or_join)
 88            _merge_order(outer_scope, inner_scope)
 89            _merge_hints(outer_scope, inner_scope)
 90            _pop_cte(inner_scope)
 91            outer_scope.clear_cache()
 92    return expression
 93
 94
 95def merge_derived_tables(expression, leave_tables_isolated=False):
 96    for outer_scope in traverse_scope(expression):
 97        for subquery in outer_scope.derived_tables:
 98            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
 99            alias = subquery.alias_or_name
100            inner_scope = outer_scope.sources[alias]
101            if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
102                _rename_inner_sources(outer_scope, inner_scope, alias)
103                _merge_from(outer_scope, inner_scope, subquery, alias)
104                _merge_expressions(outer_scope, inner_scope, alias)
105                _merge_joins(outer_scope, inner_scope, from_or_join)
106                _merge_where(outer_scope, inner_scope, from_or_join)
107                _merge_order(outer_scope, inner_scope)
108                _merge_hints(outer_scope, inner_scope)
109                outer_scope.clear_cache()
110
111    return expression
112
113
def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
    """
    Return True if `inner_select` can be merged into outer query.

    Args:
        outer_scope (Scope): scope of the query that selects from the inner query
        inner_scope (Scope): scope of the derived table or CTE being considered
        leave_tables_isolated (bool): if True, refuse to merge when the outer query
            selects from more than one source
        from_or_join (exp.From|exp.Join): clause of the outer query that contains the
            reference to the inner query
    Returns:
        bool: True if can be merged
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        """
        Return True if an outer WHERE/GROUP/ORDER/JOIN/HAVING clause or aggregate uses
        an inner projection whose name matches a window expression's parent.

        Merging would inline that window function into one of those clauses.
        """
        window_expressions = inner_select.find_all(exp.Window)
        # Presumably each window's parent is the Alias naming its projection — confirm.
        window_alias_names = {window.parent.alias_or_name for window in window_expressions}
        inner_select_name = from_or_join.alias_or_name
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_alias_names
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}
        return any(
            col.table != inner_from_table
            for selection in selections
            for col in inner_projections[selection].find_all(exp.Column)
        )

    def _is_recursive():
        # Recursive CTEs look like this:
        #     WITH RECURSIVE cte AS (
        #       SELECT * FROM x  <-- inner scope
        #       UNION ALL
        #       SELECT * FROM cte  <-- outer scope
        #     )
        cte = inner_scope.expression.parent
        node = outer_scope.expression.parent

        while node:
            if node is cte:
                return True
            node = node.parent
        return False

    inner_where = inner_select.args.get("where")
    # `or []` also covers a "joins" arg that is present but set to None.
    outer_joins = outer_scope.expression.args.get("joins") or []

    # bool() so callers always get True/False as documented, not None or an AST node
    # leaked from a short-circuited `and`.
    return bool(
        isinstance(outer_scope.expression, exp.Select)
        and not outer_scope.expression.is_star
        and isinstance(inner_select, exp.Select)
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from")
        and not outer_scope.pivots
        and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # A filtered inner query joined with an outer join is not merged ...
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_where
            and from_or_join.side in ("FULL", "LEFT", "RIGHT")
        )
        # ... nor a filtered inner query in FROM when the outer query has a RIGHT/FULL join.
        and not (
            isinstance(from_or_join, exp.From)
            and inner_where
            and any(j.side in ("FULL", "RIGHT") for j in outer_joins)
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
        and not _is_recursive()
        and not (inner_select.args.get("order") and outer_scope.is_union)
    )
219
220
def _rename_inner_sources(outer_scope, inner_scope, alias):
    """
    Renames any sources in the inner query that conflict with names in the outer query.

    A replacement name must be unused in both scopes, because the inner sources end
    up next to the outer ones after merging. It must also not repeat a name already
    handed out during this call.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str): name of the inner query in the outer query. The merge removes
            this source, so it is not treated as a conflict.
    """
    outer_taken = set(outer_scope.selected_sources)
    inner_taken = set(inner_scope.selected_sources)
    conflicts = (outer_taken & inner_taken) - {alias}

    # Reserve every existing name from either scope. Otherwise a generated name could
    # collide with an unrelated source the inner query already uses.
    taken = outer_taken | inner_taken

    # Sorted so renaming happens in a deterministic order.
    for conflict in sorted(conflicts):
        new_name = find_new_name(taken, conflict)
        taken.add(new_name)

        source, _ = inner_scope.selected_sources[conflict]
        new_alias = exp.to_identifier(new_name)

        if isinstance(source, exp.Subquery):
            source.set("alias", exp.TableAlias(this=new_alias))
        elif isinstance(source, exp.Table) and source.alias:
            source.set("alias", new_alias)
        elif isinstance(source, exp.Table):
            source.replace(exp.alias_(source, new_alias))

        for column in inner_scope.source_columns(conflict):
            column.set("table", exp.to_identifier(new_name))

        inner_scope.rename_source(conflict, new_name)
251
252
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
    """
    Merge FROM clause of inner query into outer query.

    The inner query's FROM source takes the place of `node_to_replace` in the outer
    query. Join hints that named the replaced node are repointed at the new source,
    and the outer scope's source map is updated to match.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table)
        alias (str): source name being removed from the outer scope
    """
    replacement = inner_scope.expression.args["from"].this
    # Carry over any joins attached to the node being replaced.
    replacement.set("joins", node_to_replace.args.get("joins"))
    node_to_replace.replace(replacement)

    old_name = node_to_replace.alias_or_name
    new_name = replacement.alias_or_name

    hinted_tables = (
        table for join_hint in outer_scope.join_hints for table in join_hint.find_all(exp.Table)
    )
    for table in hinted_tables:
        if table.alias_or_name == old_name:
            table.set("this", exp.to_identifier(new_name))

    outer_scope.remove_source(alias)
    outer_scope.add_source(new_name, inner_scope.sources[new_name])
275
276
def _merge_joins(outer_scope, inner_scope, from_or_join):
    """
    Merge JOIN clauses of inner query into outer query.

    The inner joins are registered as sources of the outer scope. They are then
    spliced into the outer join list right where the inner query was referenced:
    at the front when it was the FROM source, otherwise right after `from_or_join`.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    inner_joins = inner_scope.expression.args.get("joins") or []
    if not inner_joins:
        return

    for join in inner_joins:
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    # `or []` also covers a "joins" arg that is present but None, which would make the
    # slice assignment below raise a TypeError.
    outer_joins = list(outer_scope.expression.args.get("joins") or [])

    # Maintain the join order
    if isinstance(from_or_join, exp.From):
        position = 0
    else:
        position = outer_joins.index(from_or_join) + 1
    outer_joins[position:position] = inner_joins

    outer_scope.expression.set("joins", outer_joins)
305
306
def _merge_expressions(outer_scope, inner_scope, alias):
    """
    Merge projections of inner query into outer query.

    Every outer column qualified with `alias` is replaced by a copy of the inner
    projection it names. The copy is parenthesized when it sits directly under a
    unary/binary operator and is not one of `SAFE_TO_REPLACE_UNWRAPPED`.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str): name under which the outer query refers to the inner query
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for expression in inner_scope.expression.expressions:
        projection_name = expression.alias_or_name
        if not projection_name:
            continue
        columns_to_replace = outer_columns.get(projection_name, [])

        expression = expression.unalias()
        must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED)

        for column in columns_to_replace:
            # Build each replacement from a fresh copy. A wrap done for one reference
            # then cannot accumulate on later ones (nested parentheses), and nodes of
            # the inner query are never re-parented.
            replacement = expression.copy()
            # Ensures we don't alter the intended operator precedence if there's additional
            # context surrounding the outer expression (i.e. it's not a simple projection).
            if must_wrap_expression and isinstance(column.parent, (exp.Unary, exp.Binary)):
                replacement = exp.paren(replacement, copy=False)

            column.replace(replacement)
339
340
def _merge_where(outer_scope, inner_scope, from_or_join):
    """
    Merge WHERE clause of inner query into outer query.

    When the inner query was joined, its predicate goes into that join's ON clause
    if the predicate only references sources joined up to that point (the FROM
    source plus the joins up to and including `from_or_join`). In every other case
    it is AND-ed into the outer WHERE.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    where = inner_scope.expression.args.get("where")
    if not where or not where.this:
        return

    expression = outer_scope.expression

    if isinstance(from_or_join, exp.Join):
        # Merge predicates from an outer join to the ON clause
        # if it only has columns that are already joined
        from_ = expression.args.get("from")
        # Must be a set: the empty fallback is grown with `.add` below and compared
        # with `<=` against a set of table names.
        sources = {from_.alias_or_name} if from_ else set()

        for join in expression.args["joins"]:
            source = join.alias_or_name
            sources.add(source)
            if source == from_or_join.alias_or_name:
                break

        if exp.column_table_names(where.this) <= sources:
            from_or_join.on(where.this, copy=False)
            # NOTE(review): looks redundant after the in-place `.on(..., copy=False)`
            # above; kept to preserve existing behavior.
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    expression.where(where.this, copy=False)
374
375
def _merge_order(outer_scope, inner_scope):
    """
    Merge ORDER clause of inner query into outer query.

    The inner ORDER BY is copied over only when the outer query is a plain
    single-source SELECT. That means: no GROUP BY, DISTINCT, HAVING or ORDER BY of
    its own, and no aggregate function in its projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    outer = outer_scope.expression

    if any(outer.args.get(arg) for arg in ("group", "distinct", "having", "order")):
        return
    if len(outer_scope.selected_sources) != 1:
        return
    if any(projection.find(exp.AggFunc) for projection in outer.expressions):
        return

    outer.set("order", inner_scope.expression.args.get("order"))
394
395
396def _merge_hints(outer_scope, inner_scope):
397    inner_scope_hint = inner_scope.expression.args.get("hint")
398    if not inner_scope_hint:
399        return
400    outer_scope_hint = outer_scope.expression.args.get("hint")
401    if outer_scope_hint:
402        for hint_expression in inner_scope_hint.expressions:
403            outer_scope_hint.append("expressions", hint_expression)
404    else:
405        outer_scope.expression.set("hint", inner_scope_hint)
406
407
408def _pop_cte(inner_scope):
409    """
410    Remove CTE from the AST.
411
412    Args:
413        inner_scope (sqlglot.optimizer.scope.Scope)
414    """
415    cte = inner_scope.expression.parent
416    with_ = cte.parent
417    if len(with_.expressions) == 1:
418        with_.pop()
419    else:
420        cte.pop()
def merge_subqueries(expression, leave_tables_isolated=False):
    """
    Rewrite sqlglot AST to merge derived tables into the outer query.

    This also merges CTEs if they are selected from only once.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, never merge an inner query into an outer
            query that selects from more than one source
    Returns:
        sqlglot.Expression: optimized expression
    """
    # Two passes: single-use CTEs first, then derived tables in the resulting tree.
    with_ctes_merged = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(with_ctes_merged, leave_tables_isolated)

Rewrite sqlglot AST to merge derived tables into the outer query.

This also merges CTEs if they are selected from only once.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x CROSS JOIN y'

If leave_tables_isolated is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:

>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • leave_tables_isolated (bool): if True, never merge an inner query into an outer query that selects from more than one source
Returns:

sqlglot.Expression: optimized expression

UNMERGABLE_ARGS = {'offset', 'prewhere', 'match', 'locks', 'qualify', 'windows', 'pivots', 'cluster', 'settings', 'having', 'group', 'options', 'distinct', 'with', 'distribute', 'sample', 'format', 'connect', 'laterals', 'limit', 'sort', 'into', 'kind'}
SAFE_TO_REPLACE_UNWRAPPED = (<class 'sqlglot.expressions.Column'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.Func'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.Paren'>)
def merge_ctes(expression, leave_tables_isolated=False):
    """
    Merge every CTE that is selected from exactly once into the query selecting it.

    CTEs referenced more than once are left untouched. A merged CTE is removed
    from its WITH clause.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, don't merge into an outer query that
            selects from more than one source
    Returns:
        sqlglot.Expression: `expression`, rewritten in place
    """
    scopes = traverse_scope(expression)

    # All places where we select from CTEs. Keyed on the identity of the CTE scope
    # so CTEs that are selected from multiple times can be detected.
    cte_selections = defaultdict(list)
    for outer_scope in scopes:
        for table, inner_scope in outer_scope.selected_sources.values():
            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
                cte_selections[id(inner_scope)].append((outer_scope, inner_scope, table))

    # Only the key's grouping matters here, so iterate the values directly.
    singular_cte_selections = [
        selections[0] for selections in cte_selections.values() if len(selections) == 1
    ]
    for outer_scope, inner_scope, table in singular_cte_selections:
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            alias = table.alias_or_name
            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, table, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_order(outer_scope, inner_scope)
            _merge_hints(outer_scope, inner_scope)
            _pop_cte(inner_scope)
            outer_scope.clear_cache()
    return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Merge derived tables (subqueries under FROM/JOIN) into the queries selecting them.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, don't merge into an outer query that
            selects from more than one source
    Returns:
        sqlglot.Expression: `expression`, rewritten in place
    """
    for outer_scope in traverse_scope(expression):
        for derived_table in outer_scope.derived_tables:
            clause = derived_table.find_ancestor(exp.From, exp.Join)
            name = derived_table.alias_or_name
            inner_scope = outer_scope.sources[name]
            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, clause):
                continue

            _rename_inner_sources(outer_scope, inner_scope, name)
            _merge_from(outer_scope, inner_scope, derived_table, name)
            _merge_expressions(outer_scope, inner_scope, name)
            _merge_joins(outer_scope, inner_scope, clause)
            _merge_where(outer_scope, inner_scope, clause)
            _merge_order(outer_scope, inner_scope)
            _merge_hints(outer_scope, inner_scope)
            # Drop the outer scope's cached state now that its tree has changed.
            outer_scope.clear_cache()

    return expression