Edit on GitHub

sqlglot.optimizer.eliminate_subqueries

  1import itertools
  2
  3from sqlglot import expressions as exp
  4from sqlglot.helper import find_new_name
  5from sqlglot.optimizer.scope import build_scope
  6
  7
  8def eliminate_subqueries(expression):
  9    """
 10    Rewrite derived tables as CTES, deduplicating if possible.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
 15        >>> eliminate_subqueries(expression).sql()
 16        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
 17
 18    This also deduplicates common subqueries:
 19        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
 20        >>> eliminate_subqueries(expression).sql()
 21        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
 22
 23    Args:
 24        expression (sqlglot.Expression): expression
 25    Returns:
 26        sqlglot.Expression: expression
 27    """
 28    if isinstance(expression, exp.Subquery):
 29        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
 30        eliminate_subqueries(expression.this)
 31        return expression
 32
 33    root = build_scope(expression)
 34
 35    if not root:
 36        return expression
 37
 38    # Map of alias->Scope|Table
 39    # These are all aliases that are already used in the expression.
 40    # We don't want to create new CTEs that conflict with these names.
 41    taken = {}
 42
 43    # All CTE aliases in the root scope are taken
 44    for scope in root.cte_scopes:
 45        taken[scope.expression.parent.alias] = scope
 46
 47    # All table names are taken
 48    for scope in root.traverse():
 49        taken.update(
 50            {
 51                source.name: source
 52                for _, source in scope.sources.items()
 53                if isinstance(source, exp.Table)
 54            }
 55        )
 56
 57    # Map of Expression->alias
 58    # Existing CTES in the root expression. We'll use this for deduplication.
 59    existing_ctes = {}
 60
 61    with_ = root.expression.args.get("with")
 62    recursive = False
 63    if with_:
 64        recursive = with_.args.get("recursive")
 65        for cte in with_.expressions:
 66            existing_ctes[cte.this] = cte.alias
 67    new_ctes = []
 68
 69    # We're adding more CTEs, but we want to maintain the DAG order.
 70    # Derived tables within an existing CTE need to come before the existing CTE.
 71    for cte_scope in root.cte_scopes:
 72        # Append all the new CTEs from this existing CTE
 73        for scope in cte_scope.traverse():
 74            if scope is cte_scope:
 75                # Don't try to eliminate this CTE itself
 76                continue
 77            new_cte = _eliminate(scope, existing_ctes, taken)
 78            if new_cte:
 79                new_ctes.append(new_cte)
 80
 81        # Append the existing CTE itself
 82        new_ctes.append(cte_scope.expression.parent)
 83
 84    # Now append the rest
 85    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
 86        for child_scope in scope.traverse():
 87            new_cte = _eliminate(child_scope, existing_ctes, taken)
 88            if new_cte:
 89                new_ctes.append(new_cte)
 90
 91    if new_ctes:
 92        query = expression.expression if isinstance(expression, exp.DDL) else expression
 93        query.set("with", exp.With(expressions=new_ctes, recursive=recursive))
 94
 95    return expression
 96
 97
 98def _eliminate(scope, existing_ctes, taken):
 99    if scope.is_derived_table:
100        return _eliminate_derived_table(scope, existing_ctes, taken)
101
102    if scope.is_cte:
103        return _eliminate_cte(scope, existing_ctes, taken)
104
105
def _eliminate_derived_table(scope, existing_ctes, taken):
    """
    Replace the derived table backing `scope` with a reference to a root-level CTE.

    Returns:
        The new exp.CTE, or None when the subquery is left in place (pivoted or
        LATERAL) or when it duplicates an existing CTE.
    """
    outer = scope.parent

    # Leave these alone so we don't:
    # - drop the "pivot" arg from a pivoted subquery
    # - eliminate a lateral correlated subquery
    if outer.pivots or isinstance(outer.expression, exp.Lateral):
        return None

    # Strip redundant exp.Subquery wrappers. Look this up before calling _new_cte,
    # which builds an exp.CTE around scope.expression (possibly re-parenting it).
    wrapper = scope.expression.parent.unwrap()
    cte_name, new_cte = _new_cte(scope, existing_ctes, taken)

    reference = exp.alias_(exp.table_(cte_name), alias=wrapper.alias or cte_name)
    # Carry over any joins attached to the node being replaced.
    reference.set("joins", wrapper.args.get("joins"))
    wrapper.replace(reference)

    return new_cte
122
123
def _eliminate_cte(scope, existing_ctes, taken):
    """
    Move a nested CTE up to the root WITH clause and repoint its references.

    Returns:
        The new exp.CTE, or None if it duplicates an existing CTE.
    """
    # Capture the CTE node before _new_cte wraps scope.expression in a new exp.CTE.
    cte_node = scope.expression.parent
    cte_name, new_cte = _new_cte(scope, existing_ctes, taken)

    # Detach the CTE from its WITH clause, and drop the clause once it is empty.
    with_clause = cte_node.parent
    cte_node.pop()
    if not with_clause.expressions:
        with_clause.pop()

    # Any table that selected from this CTE now points at its (possibly renamed)
    # root-level name, keeping its original alias.
    for descendant in scope.parent.traverse():
        for table, source in descendant.selected_sources.values():
            if source is not scope:
                continue
            table.replace(
                exp.alias_(exp.table_(cte_name), alias=table.alias_or_name, copy=False)
            )

    return new_cte
141
142
def _new_cte(scope, existing_ctes, taken):
    """
    Choose a root-level name for `scope` and, unless it is a duplicate, build its CTE.

    The chosen name is recorded in `taken`. A non-duplicate query is also
    recorded in `existing_ctes`, so that later identical queries reuse it.

    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.
    """
    query = scope.expression
    duplicate_cte_alias = existing_ctes.get(query)

    # Start from the query's own alias, falling back to a generated "cte..." name.
    name = query.parent.alias or find_new_name(taken=taken, base="cte")

    if duplicate_cte_alias:
        # Identical query already hoisted: reuse its alias.
        name = duplicate_cte_alias
    elif taken.get(name):
        # Alias collides with an existing name: derive a fresh one from it.
        name = find_new_name(taken=taken, base=name)

    taken[name] = scope

    if duplicate_cte_alias:
        return name, None

    existing_ctes[query] = name
    cte = exp.CTE(this=query, alias=exp.TableAlias(this=exp.to_identifier(name)))
    return name, cte
def eliminate_subqueries(expression):
    """
    Rewrite derived tables as CTEs, deduplicating if possible.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'

    Args:
        expression (sqlglot.Expression): the expression to rewrite.
    Returns:
        sqlglot.Expression: the same expression object, modified in place.
    """
    if isinstance(expression, exp.Subquery):
        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)

    if not root:
        return expression

    # Map of alias->Scope|Table
    # These are all aliases that are already used in the expression.
    # We don't want to create new CTEs that conflict with these names.
    taken = {}

    # All CTE aliases in the root scope are taken
    for scope in root.cte_scopes:
        taken[scope.expression.parent.alias] = scope

    # All table names are taken
    for scope in root.traverse():
        taken.update(
            {
                source.name: source
                for source in scope.sources.values()
                if isinstance(source, exp.Table)
            }
        )

    # Map of Expression->alias
    # Existing CTEs in the root expression. We'll use this for deduplication.
    existing_ctes = {}

    with_ = root.expression.args.get("with")
    recursive = False
    if with_:
        # Preserve the RECURSIVE flag on the rebuilt WITH clause.
        recursive = with_.args.get("recursive")
        for cte in with_.expressions:
            existing_ctes[cte.this] = cte.alias
    new_ctes = []

    # We're adding more CTEs, but we want to maintain the DAG order.
    # Derived tables within an existing CTE need to come before the existing CTE.
    for cte_scope in root.cte_scopes:
        # Append all the new CTEs from this existing CTE
        for scope in cte_scope.traverse():
            if scope is cte_scope:
                # Don't try to eliminate this CTE itself
                continue
            new_cte = _eliminate(scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

        # Append the existing CTE itself
        new_ctes.append(cte_scope.expression.parent)

    # Now append the rest
    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
        for child_scope in scope.traverse():
            new_cte = _eliminate(child_scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

    if new_ctes:
        # For DDL (e.g. CREATE ... AS SELECT) the WITH clause belongs on the inner query.
        query = expression.expression if isinstance(expression, exp.DDL) else expression
        query.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression

Rewrite derived tables as CTEs, deduplicating if possible.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Arguments:
  • expression (sqlglot.Expression): the expression to rewrite.
Returns:
  sqlglot.Expression: the same expression object, modified in place.